Skip to content

Commit

Permalink
Wrap mask with TreeVec (#702)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Jul 8, 2024
1 parent ab2322c commit 5ab9ac2
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 73 deletions.
6 changes: 3 additions & 3 deletions crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};

pub trait AirExt: Air {
fn composition_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -55,12 +55,12 @@ pub trait AirExt: Air {
fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
mask_values: &ComponentVec<Vec<SecureField>>,
mask_values: &Vec<TreeVec<Vec<Vec<SecureField>>>>,
random_coeff: SecureField,
interaction_elements: &InteractionElements,
) -> SecureField {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
zip_eq(self.components(), &mask_values.0).for_each(|(component, mask)| {
zip_eq(self.components(), mask_values).for_each(|(component, mask)| {
component.evaluate_constraint_quotients_at_point(
point,
mask,
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub trait Component {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
);
Expand Down
7 changes: 7 additions & 0 deletions crates/prover/src/core/pcs/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@ impl<T> TreeVec<T> {
TreeVec(self.0.into_iter().map(f).collect())
}
pub fn zip<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
let other = other.into();
TreeVec(self.0.into_iter().zip(other.0).collect())
}
pub fn zip_eq<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
let other = other.into();
TreeVec(zip_eq(self.0, other.0).collect())
}
pub fn as_ref(&self) -> TreeVec<&T> {
TreeVec(self.iter().collect())
}
pub fn as_mut(&mut self) -> TreeVec<&mut T> {
TreeVec(self.iter_mut().collect())
}
}

/// Converts `&TreeVec<T>` to `TreeVec<&T>`.
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/pcs/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ impl CommitmentSchemeVerifier {
// Verify merkle decommitments.
self.trees
.as_ref()
.zip(proof.decommitments)
.zip(proof.queried_values.clone())
.zip_eq(proof.decommitments)
.zip_eq(proof.queried_values.clone())
.map(|((tree, decommitment), queried_values)| {
let queries = fri_query_domains
.iter()
Expand Down
63 changes: 21 additions & 42 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::hasher::Hasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::vcs::verifier::MerkleVerificationError;
use crate::core::ComponentVec;

type Channel = Blake2sChannel;
type ChannelHasher = Blake2sHasher;
Expand Down Expand Up @@ -237,52 +236,32 @@ pub fn verify(
commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel)
}

#[allow(clippy::type_complexity)]
/// Structures the tree-wise sampled values into component-wise OODS values and a composition
/// polynomial OODS value.
fn sampled_values_to_mask(
air: &impl Air,
sampled_values: &TreeVec<ColumnVec<Vec<SecureField>>>,
) -> Result<(ComponentVec<Vec<SecureField>>, SecureField), InvalidOodsSampleStructure> {
// Retrieve sampled mask values for each component.
let flat_trace_values = &mut sampled_values
.first()
.ok_or(InvalidOodsSampleStructure)?
.iter();
let mut trace_oods_values = vec![];
air.components().iter().for_each(|component| {
let n_trace_points = component.mask_points(CirclePoint::zero())[0].len();
trace_oods_values.push(
flat_trace_values
.take(n_trace_points)
.cloned()
.collect_vec(),
)
});

if air.n_interaction_phases() == 2 {
let interaction_values = &mut sampled_values
.get(1)
.ok_or(InvalidOodsSampleStructure)?
.iter();

air.components()
.iter()
.zip_eq(&mut trace_oods_values)
.for_each(|(component, values)| {
let n_interaction_points = component.mask_points(CirclePoint::zero())[1].len();
values.extend(
interaction_values
.take(n_interaction_points)
.cloned()
.collect_vec(),
)
});
}
) -> Result<(Vec<TreeVec<Vec<Vec<SecureField>>>>, SecureField), InvalidOodsSampleStructure> {
let mut sampled_values = sampled_values.as_ref();
let composition_values = sampled_values.pop().ok_or(InvalidOodsSampleStructure)?;

let mut sample_iters = sampled_values.map(|tree_value| tree_value.iter());
let trace_oods_values = air
.components()
.iter()
.map(|component| {
component
.mask_points(CirclePoint::zero())
.zip(sample_iters.as_mut())
.map(|(mask_per_tree, tree_iter)| {
tree_iter.take(mask_per_tree.len()).cloned().collect_vec()
})
})
.collect_vec();

let composition_partial_sampled_values =
sampled_values.last().ok_or(InvalidOodsSampleStructure)?;
let composition_oods_value = SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(
composition_partial_sampled_values
composition_values
.iter()
.flatten()
.cloned()
Expand All @@ -291,7 +270,7 @@ fn sampled_values_to_mask(
.map_err(|_| InvalidOodsSampleStructure)?,
);

Ok((ComponentVec(trace_oods_values), composition_oods_value))
Ok((trace_oods_values, composition_oods_value))
}

/// Error when the sampled values have an invalid structure.
Expand Down Expand Up @@ -428,7 +407,7 @@ mod tests {
fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &crate::core::ColumnVec<Vec<SecureField>>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
Expand Down
14 changes: 6 additions & 8 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,17 @@ impl Component for FibonacciComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
evaluation_accumulator.accumulate(
self.step_constraint_eval_quotient_by_mask(point, &mask[0][..].try_into().unwrap()),
);
evaluation_accumulator.accumulate(
self.boundary_constraint_eval_quotient_by_mask(
point,
&mask[0][..1].try_into().unwrap(),
),
self.step_constraint_eval_quotient_by_mask(point, &mask[0][0][..].try_into().unwrap()),
);
evaluation_accumulator.accumulate(self.boundary_constraint_eval_quotient_by_mask(
point,
&mask[0][0][..1].try_into().unwrap(),
));
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ mod tests {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
fib.air.component.evaluate_constraint_quotients_at_point(
point,
&mask_values,
&TreeVec::new(vec![mask_values]),
&mut evaluation_accumulator,
&InteractionElements::new(BTreeMap::new()),
);
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ impl Component for PoseidonComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
let mut eval = PoseidonEvalAtPoint {
mask,
mask: &mask[0],
evaluation_accumulator,
col_index: 0,
denom_inverse,
Expand Down
28 changes: 17 additions & 11 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,57 @@ impl WideFibComponent {
fn evaluate_lookup_boundary_constraint_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
mask[1][i][0]
}));
let numerator = (value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
&[
mask[0][self.n_columns() - 2][0],
mask[0][self.n_columns() - 1][0],
],
alpha,
z,
))
- shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z);
- shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z);
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}

fn evaluate_lookup_step_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
mask[1][i][0]
}));
let prev_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][1]
mask[1][i][1]
}));
let numerator = (value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
&[
mask[0][self.n_columns() - 2][0],
mask[0][self.n_columns() - 1][0],
],
alpha,
z,
))
- (prev_value * shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z));
- (prev_value * shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z));
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
Expand Down Expand Up @@ -165,7 +171,7 @@ impl Component for WideFibComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
) {
Expand All @@ -186,7 +192,7 @@ impl Component for WideFibComponent {
);
self.evaluate_trace_step_constraints_at_point(
point,
mask,
&mask[0],
evaluation_accumulator,
constraint_zero_domain,
);
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ impl Component for SimdWideFibComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
for i in 0..self.n_columns() - 2 {
let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0];
let numerator = mask[0][i][0].square() + mask[0][i + 1][0].square() - mask[0][i + 2][0];
evaluation_accumulator.accumulate(numerator * denom_inverse);
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/trace_generation/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod tests {
fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &ColumnVec<Vec<SecureField>>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
Expand Down

0 comments on commit 5ab9ac2

Please sign in to comment.