Skip to content

Commit

Permalink
Use SecureField in interaction.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed May 27, 2024
1 parent 43176a4 commit f54597b
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 70 deletions.
34 changes: 21 additions & 13 deletions crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::poly::circle::{CircleEvaluation, SecureCirclePoly};
use crate::core::poly::circle::{CircleEvaluation, CirclePoly, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::MerkleOps;
Expand Down Expand Up @@ -51,7 +51,7 @@ pub trait AirExt: Air {
for component in self.components() {
ids.extend(component.interaction_element_ids());
}
let elements = channel.draw_felts(ids.len()).into_iter().map(|e| e.0 .0);
let elements = channel.draw_felts(ids.len());
InteractionElements(BTreeMap::from_iter(zip_eq(ids, elements)))
}

Expand Down Expand Up @@ -132,18 +132,26 @@ pub trait AirProverExt<B: Backend>: AirProver<B> {
&self,
trace: &ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
) -> Vec<CirclePoly<B>> {
let trace_iter = &mut trace.iter();
ComponentVec(
self.prover_components()
.iter()
.map(|component| {
let n_columns = component.trace_log_degree_bounds()[0].len();
let trace_columns = trace_iter.take(n_columns).collect_vec();
component.write_interaction_trace(&trace_columns, elements)
})
.collect(),
)

self.prover_components()
.iter()
.flat_map(|component| {
let n_columns = component.trace_log_degree_bounds()[0].len();
let trace_columns = trace_iter.take(n_columns).collect_vec();
component
.write_interaction_trace(&trace_columns, elements)
.into_iter()
.flat_map(|eval| {
eval.values.columns.map(|c| {
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(eval.domain, c)
.interpolate()
})
})
.collect_vec()
})
.collect()
}

fn compute_composition_polynomial(
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::circle::CirclePoint;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::pcs::TreeVec;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::circle::{CircleEvaluation, CirclePoly, SecureEvaluation};
use super::poly::BitReversedOrder;
use super::{ColumnVec, InteractionElements};

Expand Down Expand Up @@ -60,7 +60,7 @@ pub trait ComponentTraceWriter<B: Backend> {
&self,
trace: &ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
) -> ColumnVec<SecureEvaluation<B>>;
}

pub trait ComponentProver<B: Backend>: Component + ComponentTraceWriter<B> {
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut, Index};

use self::fields::m31::BaseField;
use self::fields::qm31::SecureField;

pub mod air;
pub mod backend;
Expand Down Expand Up @@ -61,10 +61,10 @@ impl<T> DerefMut for ComponentVec<T> {
}
}

pub struct InteractionElements(BTreeMap<String, BaseField>);
pub struct InteractionElements(BTreeMap<String, SecureField>);

impl InteractionElements {
pub fn new(elements: BTreeMap<String, BaseField>) -> Self {
pub fn new(elements: BTreeMap<String, SecureField>) -> Self {
Self(elements)
}

Expand All @@ -74,7 +74,7 @@ impl InteractionElements {
}

impl Index<&str> for InteractionElements {
type Output = BaseField;
type Output = SecureField;

fn index(&self, index: &str) -> &Self::Output {
// TODO(AlonH): Return an error if the key is not found.
Expand Down
22 changes: 7 additions & 15 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use tracing::{span, Level};

use super::air::AirProver;
use super::backend::Backend;
use super::fields::secure_column::SECURE_EXTENSION_DEGREE;
use super::fri::FriVerificationError;
use super::pcs::{CommitmentSchemeProof, TreeVec};
use super::poly::circle::{CanonicCoset, SecureCirclePoly, MAX_CIRCLE_DOMAIN_LOG_SIZE};
Expand Down Expand Up @@ -69,16 +70,7 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
span.exit();

let interaction_elements = air.interaction_elements(channel);
let interaction_traces = air.interact(&trace, &interaction_elements);
let interaction_trace_polys = interaction_traces
.0
.into_iter()
.flat_map(|trace| {
trace
.into_iter()
.map(|poly| poly.interpolate_with_twiddles(twiddles))
})
.collect_vec();
let interaction_trace_polys = air.interact(&trace, &interaction_elements);
let n_interaction_traces = interaction_trace_polys.len();
if n_interaction_traces > 0 {
commitment_scheme.commit(interaction_trace_polys, channel, twiddles);
Expand Down Expand Up @@ -116,7 +108,7 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
let mut sample_points = air.mask_points(oods_point);

// Get composition polynomial sample points.
sample_points.push(vec![vec![oods_point]; 4]);
sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]);

// Prove the trace and composition OODS values, and retrieve them.
let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel, twiddles);
Expand Down Expand Up @@ -207,7 +199,7 @@ pub fn verify(
// Read composition polynomial commitment.
commitment_scheme.commit(
*proof.commitments.last().unwrap(),
&[air.composition_log_degree_bound(); 4],
&[air.composition_log_degree_bound(); SECURE_EXTENSION_DEGREE],
channel,
);

Expand All @@ -218,7 +210,7 @@ pub fn verify(
let mut sample_points = air.mask_points(oods_point);

// Get composition polynomial sample points.
sample_points.push(vec![vec![oods_point]; 4]);
sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]);

// TODO(spapini): Save clone.
let (trace_oods_values, composition_oods_value) = sampled_values_to_mask(
Expand Down Expand Up @@ -350,7 +342,7 @@ mod tests {
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, MAX_CIRCLE_DOMAIN_LOG_SIZE,
CanonicCoset, CircleDomain, CircleEvaluation, SecureEvaluation, MAX_CIRCLE_DOMAIN_LOG_SIZE,
};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, ProvingError};
Expand Down Expand Up @@ -419,7 +411,7 @@ mod tests {
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
) -> ColumnVec<SecureEvaluation<CpuBackend>> {
vec![]
}
}
Expand Down
14 changes: 9 additions & 5 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::iter::Peekable;
use std::ops::Add;

use num_traits::One;
use num_traits::{One, Zero};

use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
Expand Down Expand Up @@ -92,12 +93,15 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec<SecureF
/// Alpha and z should be secure field elements for soundness.
pub fn shifted_secure_combination<F: ExtensionOf<BaseField>>(
values: &[F],
alpha: BaseField,
z: BaseField,
) -> F {
alpha: SecureField,
z: SecureField,
) -> SecureField
where
SecureField: Add<F, Output = SecureField>,
{
let res = values
.iter()
.fold(F::zero(), |acc, &value| acc * alpha + value);
.fold(SecureField::zero(), |acc, &value| acc * alpha + value);
res - z
}

Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::{ColumnVec, InteractionElements};
Expand Down Expand Up @@ -127,7 +127,7 @@ impl ComponentTraceWriter<CpuBackend> for FibonacciComponent {
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
) -> ColumnVec<SecureEvaluation<CpuBackend>> {
vec![]
}
}
Expand Down
26 changes: 20 additions & 6 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ use crate::core::circle::CirclePoint;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{
CanonicCoset, CircleEvaluation, SecureCirclePoly, SecureEvaluation,
};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::{ColumnVec, InteractionElements};
Expand Down Expand Up @@ -69,7 +72,7 @@ impl Component for WideFibComponent {
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_column_size(); self.n_columns()],
vec![self.log_column_size(); 1],
vec![self.log_column_size(); SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -79,7 +82,7 @@ impl Component for WideFibComponent {
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point),
vec![vec![point]],
vec![vec![point]; SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -96,7 +99,11 @@ impl Component for WideFibComponent {
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let lookup_numerator = (mask[self.n_columns()][0]
let lookup_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
}));
let lookup_numerator = (lookup_value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
alpha,
Expand All @@ -120,12 +127,19 @@ impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
&self,
trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
) -> ColumnVec<SecureEvaluation<CpuBackend>> {
let interaction_trace_domain = trace[0].domain;
let trace_values = trace.iter().map(|eval| &eval.values[..]).collect_vec();
let (alpha, z) = (elements[ALPHA_ID], elements[Z_ID]);
let values = write_lookup_column(&trace_values, alpha, z);
let eval = CircleEvaluation::new(interaction_trace_domain, values);
let mut secure_column = SecureColumn::<CpuBackend>::zeros(values.len());
for (i, value) in values.into_iter().enumerate() {
secure_column.set(i, value);
}
let eval = SecureEvaluation {
domain: interaction_trace_domain,
values: secure_column,
};
vec![eval]
}
}
Expand Down
8 changes: 6 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::circle::{CanonicCoset, SecureCirclePoly};
use crate::core::utils::{bit_reverse, shifted_secure_combination};
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;
Expand Down Expand Up @@ -61,8 +61,12 @@ impl ComponentProver<CpuBackend> for WideFibComponent {
}

// Lookup constraints.
let lookup_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][i].into()
}));
lookup_numerators[i] = accum.random_coeff_powers[self.n_columns() - 2]
* ((trace_evals[1][0][i]
* ((lookup_value
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
Expand Down
33 changes: 19 additions & 14 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ mod tests {
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::IntoSlice;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, verify};
use crate::core::utils::shifted_secure_combination;
Expand All @@ -40,15 +41,15 @@ mod tests {
}

pub fn assert_constraints_on_lookup_column(
column: &[BaseField],
column: &[SecureField],
input_trace: &[Vec<BaseField>],
alpha: BaseField,
z: BaseField,
alpha: SecureField,
z: SecureField,
) {
let n_columns = input_trace.len();
let column_length = column.len();
assert_eq!(column_length, input_trace[0].len());
let mut prev_value = BaseField::one();
let mut prev_value = SecureField::one();
for (i, cell) in column.iter().enumerate() {
assert_eq!(
*cell
Expand Down Expand Up @@ -110,8 +111,8 @@ mod tests {
b: m31!(1),
};

let alpha = m31!(7);
let z = m31!(11);
let alpha = qm31!(7, 1, 3, 4);
let z = qm31!(11, 1, 2, 3);
let trace = gen_trace(&wide_fib, vec![input]);
let input_trace = trace.iter().map(|values| &values[..]).collect_vec();
let lookup_column = write_lookup_column(&input_trace, alpha, z);
Expand Down Expand Up @@ -163,15 +164,19 @@ mod tests {
.iter()
.cloned()
.enumerate()
.map(|(i, id)| (id, m31!(43 + i as u32))),
.map(|(i, id)| (id, qm31!(43 + i as u32, 1, 2, 3))),
));
let interaction_trace =
wide_fib.write_interaction_trace(&trace.iter().collect(), &interaction_elements);

let interaction_poly = interaction_trace
.iter()
.map(|trace| trace.clone().interpolate())
let interaction_poly = wide_fib
.write_interaction_trace(&trace.iter().collect(), &interaction_elements)
.into_iter()
.flat_map(|eval| {
eval.values.columns.map(|c| {
CircleEvaluation::<CpuBackend, BaseField, BitReversedOrder>::new(eval.domain, c)
.interpolate()
})
})
.collect_vec();

let interaction_trace = interaction_poly
.iter()
.map(|poly| poly.evaluate(eval_domain))
Expand Down
Loading

0 comments on commit f54597b

Please sign in to comment.