Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SecureField in interaction. #642

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::fields::qm31::SecureField;
use super::pcs::TreeVec;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::BitReversedOrder;
use super::{ColumnVec, ComponentVec, InteractionElements};
use super::{ColumnVec, InteractionElements};

pub mod accumulation;
mod air_ext;
Expand All @@ -34,7 +34,7 @@ pub trait AirTraceWriter<B: Backend>: AirTraceVerifier {
&self,
trace: &ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>;

fn to_air_prover(&self) -> &impl AirProver<B>;
}
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut, Index};

use self::fields::m31::BaseField;
use self::fields::qm31::SecureField;

pub mod air;
pub mod backend;
Expand Down Expand Up @@ -62,10 +62,10 @@ impl<T> DerefMut for ComponentVec<T> {
}

#[derive(Default)]
pub struct InteractionElements(BTreeMap<String, BaseField>);
pub struct InteractionElements(BTreeMap<String, SecureField>);

impl InteractionElements {
pub fn new(elements: BTreeMap<String, BaseField>) -> Self {
pub fn new(elements: BTreeMap<String, SecureField>) -> Self {
Self(elements)
}

Expand All @@ -75,7 +75,7 @@ impl InteractionElements {
}

impl Index<&str> for InteractionElements {
type Output = BaseField;
type Output = SecureField;

fn index(&self, index: &str) -> &Self::Output {
// TODO(AlonH): Return an error if the key is not found.
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl<B: PolyOps> SecureCirclePoly<B> {

/// Evaluates the polynomial at a point, given evaluations of its composing base field
/// polynomials at that point.
// TODO(AlonH): Move to SecureField and rename.
pub fn eval_from_partial_evals(evals: [SecureField; SECURE_EXTENSION_DEGREE]) -> SecureField {
let mut res = evals[0];
res += evals[1] * SecureField::from_u32_unchecked(0, 1, 0, 0);
Expand Down
27 changes: 12 additions & 15 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
let trace_polys = trace
.clone()
.into_iter()
.map(|poly| poly.interpolate_with_twiddles(twiddles))
.map(|eval| eval.interpolate_with_twiddles(twiddles))
.collect();
span.exit();

Expand All @@ -71,17 +71,14 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
span.exit();

let interaction_elements = air.interaction_elements(channel);
let interaction_traces = air.interact(&trace, &interaction_elements);
let interaction_trace_polys = interaction_traces
.0
.into_iter()
.flat_map(|trace| {
trace
.into_iter()
.map(|poly| poly.interpolate_with_twiddles(twiddles))
})
.collect_vec();
if !interaction_trace_polys.is_empty() {
let interaction_trace = air.interact(&trace, &interaction_elements);
if !interaction_trace.is_empty() {
let span = span!(Level::INFO, "Interaction trace interpolation").entered();
let interaction_trace_polys = interaction_trace
.into_iter()
.map(|eval| eval.interpolate_with_twiddles(twiddles))
.collect();
span.exit();
commitment_scheme.commit(interaction_trace_polys, channel, twiddles);
}

Expand Down Expand Up @@ -356,7 +353,7 @@ mod tests {
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, ProvingError};
use crate::core::test_utils::test_channel;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};
use crate::qm31;

struct TestAir<C: ComponentProver<CpuBackend>> {
Expand All @@ -380,8 +377,8 @@ mod tests {
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down
14 changes: 9 additions & 5 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::iter::Peekable;
use std::ops::Add;

use num_traits::One;
use num_traits::{One, Zero};

use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
Expand Down Expand Up @@ -92,12 +93,15 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec<SecureF
/// Alpha and z should be secure field elements for soundness.
pub fn shifted_secure_combination<F: ExtensionOf<BaseField>>(
values: &[F],
alpha: BaseField,
z: BaseField,
) -> F {
alpha: SecureField,
z: SecureField,
) -> SecureField
where
SecureField: Add<F, Output = SecureField>,
{
let res = values
.iter()
.fold(F::zero(), |acc, &value| acc * alpha + value);
.fold(SecureField::zero(), |acc, &value| acc * alpha + value);
res - z
}

Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/examples/fibonacci/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};

pub struct FibonacciAir {
pub component: FibonacciComponent,
Expand Down Expand Up @@ -38,8 +38,8 @@ impl AirTraceWriter<CpuBackend> for FibonacciAir {
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down Expand Up @@ -87,8 +87,8 @@ impl AirTraceWriter<CpuBackend> for MultiFibonacciAir {
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down
28 changes: 21 additions & 7 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ use crate::core::circle::CirclePoint;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::{ColumnVec, InteractionElements};
Expand Down Expand Up @@ -73,7 +74,7 @@ impl Component for WideFibComponent {
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_column_size(); self.n_columns()],
vec![self.log_column_size(); 1],
vec![self.log_column_size(); SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -83,7 +84,7 @@ impl Component for WideFibComponent {
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point),
vec![vec![point]],
vec![vec![point]; SECURE_EXTENSION_DEGREE],
])
}

Expand All @@ -100,7 +101,11 @@ impl Component for WideFibComponent {
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let lookup_numerator = (mask[self.n_columns()][0]
let lookup_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|i| {
mask[self.n_columns() + i][0]
}));
let lookup_numerator = (lookup_value
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
alpha,
Expand All @@ -125,12 +130,21 @@ impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let interaction_trace_domain = trace[0].domain;
let trace_values = trace.iter().map(|eval| &eval.values[..]).collect_vec();
let (alpha, z) = (elements[ALPHA_ID], elements[Z_ID]);
// TODO(AlonH): Return a secure column directly.
let values = write_lookup_column(&trace_values, alpha, z);
let eval = CircleEvaluation::new(interaction_trace_domain, values);
vec![eval]
let secure_column: SecureColumn<CpuBackend> = values.into_iter().collect();
secure_column
.columns
.into_iter()
.map(|eval| {
CircleEvaluation::<CpuBackend, BaseField, BitReversedOrder>::new(
trace[0].domain,
eval,
)
})
.collect_vec()
}
}

Expand Down
19 changes: 11 additions & 8 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, shifted_secure_combination};
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;

// TODO(AlonH): Rename file to `cpu.rs`.

impl AirTraceVerifier for WideFibAir {
fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements {
let ids = self.component.interaction_element_ids();
let elements = channel.draw_felts(ids.len()).into_iter().map(|e| e.0 .0);
let elements = channel.draw_felts(ids.len());
InteractionElements::new(BTreeMap::from_iter(zip_eq(ids, elements)))
}
}
Expand All @@ -37,10 +37,9 @@ impl AirTraceWriter<CpuBackend> for WideFibAir {
&self,
trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![self
.component
.write_interaction_trace(&trace.iter().collect(), elements)])
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
self.component
.write_interaction_trace(&trace.iter().collect(), elements)
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
Expand Down Expand Up @@ -92,8 +91,12 @@ impl ComponentProver<CpuBackend> for WideFibComponent {
}

// Lookup constraints.
let lookup_value =
SecureCirclePoly::<CpuBackend>::eval_from_partial_evals(std::array::from_fn(|j| {
trace_evals[1][j][i].into()
}));
lookup_numerators[i] = accum.random_coeff_powers[self.n_columns() - 2]
* ((trace_evals[1][0][i]
* ((lookup_value
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
Expand Down
26 changes: 13 additions & 13 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod tests {
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::IntoSlice;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CanonicCoset;
Expand All @@ -40,15 +41,15 @@ mod tests {
}

pub fn assert_constraints_on_lookup_column(
column: &[BaseField],
column: &[SecureField],
input_trace: &[Vec<BaseField>],
alpha: BaseField,
z: BaseField,
alpha: SecureField,
z: SecureField,
) {
let n_columns = input_trace.len();
let column_length = column.len();
assert_eq!(column_length, input_trace[0].len());
let mut prev_value = BaseField::one();
let mut prev_value = SecureField::one();
for (i, cell) in column.iter().enumerate() {
assert_eq!(
*cell
Expand Down Expand Up @@ -110,8 +111,8 @@ mod tests {
b: m31!(1),
};

let alpha = m31!(7);
let z = m31!(11);
let alpha = qm31!(7, 1, 3, 4);
let z = qm31!(11, 1, 2, 3);
let trace = gen_trace(&wide_fib, vec![input]);
let input_trace = trace.iter().map(|values| &values[..]).collect_vec();
let lookup_column = write_lookup_column(&input_trace, alpha, z);
Expand Down Expand Up @@ -163,15 +164,14 @@ mod tests {
.iter()
.cloned()
.enumerate()
.map(|(i, id)| (id, m31!(43 + i as u32))),
.map(|(i, id)| (id, qm31!(43 + i as u32, 1, 2, 3))),
));
let interaction_trace =
wide_fib.write_interaction_trace(&trace.iter().collect(), &interaction_elements);

let interaction_poly = interaction_trace
.iter()
.map(|trace| trace.clone().interpolate())
let interaction_poly = wide_fib
.write_interaction_trace(&trace.iter().collect(), &interaction_elements)
.into_iter()
.map(|eval| eval.interpolate())
.collect_vec();

let interaction_trace = interaction_poly
.iter()
.map(|poly| poly.evaluate(eval_domain))
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/examples/wide_fibonacci/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::{ALPHA_ID, N_COLUMNS, Z_ID};

// TODO(AlonH): Remove this once the Cpu and Simd implementations are aligned.
Expand Down Expand Up @@ -70,8 +70,8 @@ impl AirTraceWriter<SimdBackend> for SimdWideFibAir {
&self,
_trace: &ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
ComponentVec(vec![vec![]])
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<SimdBackend> {
Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/examples/wide_fibonacci/trace_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use num_traits::One;

use super::component::Input;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::utils::shifted_secure_combination;

Expand All @@ -27,13 +28,12 @@ pub fn write_trace_row(
/// the shifted secure combination of the last two elements in each row.
pub fn write_lookup_column(
input_trace: &[&[BaseField]],
// TODO(AlonH): Change alpha and z to SecureField.
alpha: BaseField,
z: BaseField,
) -> Vec<BaseField> {
alpha: SecureField,
z: SecureField,
) -> Vec<SecureField> {
let n_rows = input_trace[0].len();
let n_columns = input_trace.len();
let mut prev_value = BaseField::one();
let mut prev_value = SecureField::one();
(0..n_rows)
.map(|i| {
let numerator =
Expand Down
Loading