Skip to content

Commit

Permalink
Use SecureField in interaction.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jun 23, 2024
1 parent 60c196d commit 88f8a6d
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 72 deletions.
6 changes: 3 additions & 3 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::fields::qm31::SecureField;
use super::pcs::TreeVec;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::BitReversedOrder;
use super::{ColumnVec, ComponentVec, InteractionElements};
use super::{ColumnVec, InteractionElements};

pub mod accumulation;
mod air_ext;
Expand All @@ -34,7 +34,7 @@ pub trait AirTraceWriter<B: Backend>: AirTraceVerifier {
&self,
trace: &ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>;

fn to_air_prover(&self) -> &impl AirProver<B>;
}
Expand Down Expand Up @@ -82,7 +82,7 @@ pub trait ComponentTraceWriter<B: Backend> {
&self,
trace: &ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
}

pub trait ComponentProver<B: Backend>: Component {
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut, Index};

use self::fields::m31::BaseField;
use self::fields::qm31::SecureField;

pub mod air;
pub mod backend;
Expand Down Expand Up @@ -62,10 +62,10 @@ impl<T> DerefMut for ComponentVec<T> {
}

#[derive(Default)]
pub struct InteractionElements(BTreeMap<String, BaseField>);
pub struct InteractionElements(BTreeMap<String, SecureField>);

impl InteractionElements {
pub fn new(elements: BTreeMap<String, BaseField>) -> Self {
pub fn new(elements: BTreeMap<String, SecureField>) -> Self {
Self(elements)
}

Expand All @@ -75,7 +75,7 @@ impl InteractionElements {
}

impl Index<&str> for InteractionElements {
type Output = BaseField;
type Output = SecureField;

fn index(&self, index: &str) -> &Self::Output {
// TODO(AlonH): Return an error if the key is not found.
Expand Down
29 changes: 13 additions & 16 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
let trace_polys = trace
.clone()
.into_iter()
.map(|poly| poly.interpolate_with_twiddles(twiddles))
.map(|eval| eval.interpolate_with_twiddles(twiddles))
.collect();
span.exit();

Expand All @@ -71,17 +71,14 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
span.exit();

let interaction_elements = air.interaction_elements(channel);
let interaction_traces = air.interact(&trace, &interaction_elements);
let interaction_trace_polys = interaction_traces
.0
.into_iter()
.flat_map(|trace| {
trace
.into_iter()
.map(|poly| poly.interpolate_with_twiddles(twiddles))
})
.collect_vec();
if !interaction_trace_polys.is_empty() {
let interaction_trace = air.interact(&trace, &interaction_elements);
if !interaction_trace.is_empty() {
let span = span!(Level::INFO, "Interaction trace interpolation").entered();
let interaction_trace_polys = interaction_trace
.into_iter()
.map(|eval| eval.interpolate_with_twiddles(twiddles))
.collect();
span.exit();
commitment_scheme.commit(interaction_trace_polys, channel, twiddles);
}

Expand Down Expand Up @@ -356,7 +353,7 @@ mod tests {
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, ProvingError};
use crate::core::test_utils::test_channel;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};
use crate::qm31;

struct TestAir<C: ComponentProver<CpuBackend>> {
Expand All @@ -380,8 +377,8 @@ mod tests {
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down Expand Up @@ -444,7 +441,7 @@ mod tests {
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}
}
Expand Down
14 changes: 9 additions & 5 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::iter::Peekable;
use std::ops::Add;

use num_traits::One;
use num_traits::{One, Zero};

use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
Expand Down Expand Up @@ -92,12 +93,15 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec<SecureF
/// Alpha and z should be secure field elements for soundness.
pub fn shifted_secure_combination<F: ExtensionOf<BaseField>>(
values: &[F],
alpha: BaseField,
z: BaseField,
) -> F {
alpha: SecureField,
z: SecureField,
) -> SecureField
where
SecureField: Add<F, Output = SecureField>,
{
let res = values
.iter()
.fold(F::zero(), |acc, &value| acc * alpha + value);
.fold(SecureField::zero(), |acc, &value| acc * alpha + value);
res - z
}

Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/examples/fibonacci/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};

pub struct FibonacciAir {
pub component: FibonacciComponent,
Expand Down Expand Up @@ -38,8 +38,8 @@ impl AirTraceWriter<CpuBackend> for FibonacciAir {
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down Expand Up @@ -87,8 +87,8 @@ impl AirTraceWriter<CpuBackend> for MultiFibonacciAir {
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl ComponentTraceWriter<CpuBackend> for FibonacciComponent {
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}
}
Expand Down
32 changes: 24 additions & 8 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ use crate::core::circle::CirclePoint;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::{ColumnVec, InteractionElements};
Expand Down Expand Up @@ -73,7 +74,7 @@ impl Component for WideFibComponent {
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_column_size(); self.n_columns()],
vec![self.log_column_size(); 1],
vec![self.log_column_size(); SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -83,7 +84,7 @@ impl Component for WideFibComponent {
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point),
vec![vec![point]],
vec![vec![point]; SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -100,7 +101,11 @@ impl Component for WideFibComponent {
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let lookup_numerator = (mask[self.n_columns()][0]
let lookup_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
}));
let lookup_numerator = (lookup_value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
alpha,
Expand All @@ -124,13 +129,24 @@ impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
&self,
trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let interaction_trace_domain = trace[0].domain;
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let trace_values = trace.iter().map(|eval| &eval.values[..]).collect_vec();
let (alpha, z) = (elements[ALPHA_ID], elements[Z_ID]);
let values = write_lookup_column(&trace_values, alpha, z);
let eval = CircleEvaluation::new(interaction_trace_domain, values);
vec![eval]
let mut secure_column = SecureColumn::<CpuBackend>::zeros(values.len());
for (i, value) in values.into_iter().enumerate() {
secure_column.set(i, value);
}
secure_column
.columns
.into_iter()
.map(|eval| {
CircleEvaluation::<CpuBackend, BaseField, BitReversedOrder>::new(
trace[0].domain,
eval,
)
})
.collect_vec()
}
}

Expand Down
19 changes: 11 additions & 8 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, shifted_secure_combination};
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;

// TODO(AlonH): Rename file to `cpu.rs`.

impl AirTraceVerifier for WideFibAir {
fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements {
let ids = self.component.interaction_element_ids();
let elements = channel.draw_felts(ids.len()).into_iter().map(|e| e.0 .0);
let elements = channel.draw_felts(ids.len());
InteractionElements::new(BTreeMap::from_iter(zip_eq(ids, elements)))
}
}
Expand All @@ -37,10 +37,9 @@ impl AirTraceWriter<CpuBackend> for WideFibAir {
&self,
trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![self
.component
.write_interaction_trace(&trace.iter().collect(), elements)])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
self.component
.write_interaction_trace(&trace.iter().collect(), elements)
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down Expand Up @@ -92,8 +91,12 @@ impl ComponentProver<CpuBackend> for WideFibComponent {
}

// Lookup constraints.
let lookup_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][i].into()
}));
lookup_numerators[i] = accum.random_coeff_powers[self.n_columns() - 2]
* ((trace_evals[1][0][i]
* ((lookup_value
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
Expand Down
26 changes: 13 additions & 13 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod tests {
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::IntoSlice;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CanonicCoset;
Expand All @@ -40,15 +41,15 @@ mod tests {
}

pub fn assert_constraints_on_lookup_column(
column: &[BaseField],
column: &[SecureField],
input_trace: &[Vec<BaseField>],
alpha: BaseField,
z: BaseField,
alpha: SecureField,
z: SecureField,
) {
let n_columns = input_trace.len();
let column_length = column.len();
assert_eq!(column_length, input_trace[0].len());
let mut prev_value = BaseField::one();
let mut prev_value = SecureField::one();
for (i, cell) in column.iter().enumerate() {
assert_eq!(
*cell
Expand Down Expand Up @@ -110,8 +111,8 @@ mod tests {
b: m31!(1),
};

let alpha = m31!(7);
let z = m31!(11);
let alpha = qm31!(7, 1, 3, 4);
let z = qm31!(11, 1, 2, 3);
let trace = gen_trace(&wide_fib, vec![input]);
let input_trace = trace.iter().map(|values| &values[..]).collect_vec();
let lookup_column = write_lookup_column(&input_trace, alpha, z);
Expand Down Expand Up @@ -163,15 +164,14 @@ mod tests {
.iter()
.cloned()
.enumerate()
.map(|(i, id)| (id, m31!(43 + i as u32))),
.map(|(i, id)| (id, qm31!(43 + i as u32, 1, 2, 3))),
));
let interaction_trace =
wide_fib.write_interaction_trace(&trace.iter().collect(), &interaction_elements);

let interaction_poly = interaction_trace
.iter()
.map(|trace| trace.clone().interpolate())
let interaction_poly = wide_fib
.write_interaction_trace(&trace.iter().collect(), &interaction_elements)
.into_iter()
.map(|eval| eval.interpolate())
.collect_vec();

let interaction_trace = interaction_poly
.iter()
.map(|poly| poly.evaluate(eval_domain))
Expand Down
Loading

0 comments on commit 88f8a6d

Please sign in to comment.