Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap mask with TreeVec #702

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};

pub trait AirExt: Air {
fn composition_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -55,12 +55,12 @@ pub trait AirExt: Air {
fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
mask_values: &ComponentVec<Vec<SecureField>>,
mask_values: &Vec<TreeVec<Vec<Vec<SecureField>>>>,
random_coeff: SecureField,
interaction_elements: &InteractionElements,
) -> SecureField {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
zip_eq(self.components(), &mask_values.0).for_each(|(component, mask)| {
zip_eq(self.components(), mask_values).for_each(|(component, mask)| {
component.evaluate_constraint_quotients_at_point(
point,
mask,
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub trait Component {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
);
Expand Down
7 changes: 7 additions & 0 deletions crates/prover/src/core/pcs/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@ impl<T> TreeVec<T> {
TreeVec(self.0.into_iter().map(f).collect())
}
pub fn zip<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
let other = other.into();
TreeVec(self.0.into_iter().zip(other.0).collect())
}
pub fn zip_eq<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
let other = other.into();
TreeVec(zip_eq(self.0, other.0).collect())
}
pub fn as_ref(&self) -> TreeVec<&T> {
TreeVec(self.iter().collect())
}
pub fn as_mut(&mut self) -> TreeVec<&mut T> {
TreeVec(self.iter_mut().collect())
}
}

/// Converts `&TreeVec<T>` to `TreeVec<&T>`.
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/pcs/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ impl CommitmentSchemeVerifier {
// Verify merkle decommitments.
self.trees
.as_ref()
.zip(proof.decommitments)
.zip(proof.queried_values.clone())
.zip_eq(proof.decommitments)
.zip_eq(proof.queried_values.clone())
.map(|((tree, decommitment), queried_values)| {
let queries = fri_query_domains
.iter()
Expand Down
63 changes: 21 additions & 42 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::hasher::Hasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::vcs::verifier::MerkleVerificationError;
use crate::core::ComponentVec;

type Channel = Blake2sChannel;
type ChannelHasher = Blake2sHasher;
Expand Down Expand Up @@ -237,52 +236,32 @@ pub fn verify(
commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel)
}

#[allow(clippy::type_complexity)]
/// Structures the tree-wise sampled values into component-wise OODS values and a composition
/// polynomial OODS value.
fn sampled_values_to_mask(
air: &impl Air,
sampled_values: &TreeVec<ColumnVec<Vec<SecureField>>>,
) -> Result<(ComponentVec<Vec<SecureField>>, SecureField), InvalidOodsSampleStructure> {
// Retrieve sampled mask values for each component.
let flat_trace_values = &mut sampled_values
.first()
.ok_or(InvalidOodsSampleStructure)?
.iter();
let mut trace_oods_values = vec![];
air.components().iter().for_each(|component| {
let n_trace_points = component.mask_points(CirclePoint::zero())[0].len();
trace_oods_values.push(
flat_trace_values
.take(n_trace_points)
.cloned()
.collect_vec(),
)
});

if air.n_interaction_phases() == 2 {
let interaction_values = &mut sampled_values
.get(1)
.ok_or(InvalidOodsSampleStructure)?
.iter();

air.components()
.iter()
.zip_eq(&mut trace_oods_values)
.for_each(|(component, values)| {
let n_interaction_points = component.mask_points(CirclePoint::zero())[1].len();
values.extend(
interaction_values
.take(n_interaction_points)
.cloned()
.collect_vec(),
)
});
}
) -> Result<(Vec<TreeVec<Vec<Vec<SecureField>>>>, SecureField), InvalidOodsSampleStructure> {
let mut sampled_values = sampled_values.as_ref();
let composition_values = sampled_values.pop().ok_or(InvalidOodsSampleStructure)?;

let mut sample_iters = sampled_values.map(|tree_value| tree_value.iter());
let trace_oods_values = air
.components()
.iter()
.map(|component| {
component
.mask_points(CirclePoint::zero())
.zip(sample_iters.as_mut())
.map(|(mask_per_tree, tree_iter)| {
tree_iter.take(mask_per_tree.len()).cloned().collect_vec()
})
})
.collect_vec();

let composition_partial_sampled_values =
sampled_values.last().ok_or(InvalidOodsSampleStructure)?;
let composition_oods_value = SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(
composition_partial_sampled_values
composition_values
.iter()
.flatten()
.cloned()
Expand All @@ -291,7 +270,7 @@ fn sampled_values_to_mask(
.map_err(|_| InvalidOodsSampleStructure)?,
);

Ok((ComponentVec(trace_oods_values), composition_oods_value))
Ok((trace_oods_values, composition_oods_value))
}

/// Error when the sampled values have an invalid structure.
Expand Down Expand Up @@ -428,7 +407,7 @@ mod tests {
fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &crate::core::ColumnVec<Vec<SecureField>>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
Expand Down
14 changes: 6 additions & 8 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,17 @@ impl Component for FibonacciComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
evaluation_accumulator.accumulate(
self.step_constraint_eval_quotient_by_mask(point, &mask[0][..].try_into().unwrap()),
);
evaluation_accumulator.accumulate(
self.boundary_constraint_eval_quotient_by_mask(
point,
&mask[0][..1].try_into().unwrap(),
),
self.step_constraint_eval_quotient_by_mask(point, &mask[0][0][..].try_into().unwrap()),
);
evaluation_accumulator.accumulate(self.boundary_constraint_eval_quotient_by_mask(
point,
&mask[0][0][..1].try_into().unwrap(),
));
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ mod tests {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
fib.air.component.evaluate_constraint_quotients_at_point(
point,
&mask_values,
&TreeVec::new(vec![mask_values]),
&mut evaluation_accumulator,
&InteractionElements::new(BTreeMap::new()),
);
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ impl Component for PoseidonComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
let mut eval = PoseidonEvalAtPoint {
mask,
mask: &mask[0],
evaluation_accumulator,
col_index: 0,
denom_inverse,
Expand Down
28 changes: 17 additions & 11 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,57 @@ impl WideFibComponent {
fn evaluate_lookup_boundary_constraint_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
mask[1][i][0]
}));
let numerator = (value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
&[
mask[0][self.n_columns() - 2][0],
mask[0][self.n_columns() - 1][0],
],
alpha,
z,
))
- shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z);
- shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z);
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}

fn evaluate_lookup_step_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
interaction_elements: &InteractionElements,
) {
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
mask[1][i][0]
}));
let prev_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][1]
mask[1][i][1]
}));
let numerator = (value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
&[
mask[0][self.n_columns() - 2][0],
mask[0][self.n_columns() - 1][0],
],
alpha,
z,
))
- (prev_value * shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z));
- (prev_value * shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z));
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
Expand Down Expand Up @@ -165,7 +171,7 @@ impl Component for WideFibComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
) {
Expand All @@ -186,7 +192,7 @@ impl Component for WideFibComponent {
);
self.evaluate_trace_step_constraints_at_point(
point,
mask,
&mask[0],
evaluation_accumulator,
constraint_zero_domain,
);
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ impl Component for SimdWideFibComponent {
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
for i in 0..self.n_columns() - 2 {
let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0];
let numerator = mask[0][i][0].square() + mask[0][i + 1][0].square() - mask[0][i + 2][0];
evaluation_accumulator.accumulate(numerator * denom_inverse);
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/trace_generation/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod tests {
fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &ColumnVec<Vec<SecureField>>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
Expand Down
Loading