Skip to content

Commit

Permalink
Eval framework preperation
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 9, 2024
1 parent a9bb818 commit 126d63f
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 24 deletions.
24 changes: 22 additions & 2 deletions crates/prover/src/core/backend/simd/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use rand::distributions::{Distribution, Standard};

use super::qm31::PackedQM31;
use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds};
use crate::core::fields::m31::{pow2147483645, BaseField, M31, P};
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;

pub const LOG_N_LANES: u32 = 4;
Expand Down Expand Up @@ -149,15 +151,33 @@ impl Mul for PackedM31 {
}
}

impl Mul<BaseField> for PackedM31 {
impl Mul<M31> for PackedM31 {
type Output = Self;

#[inline(always)]
fn mul(self, rhs: BaseField) -> Self::Output {
fn mul(self, rhs: M31) -> Self::Output {
self * PackedM31::broadcast(rhs)
}
}

impl Add<QM31> for PackedM31 {
type Output = PackedQM31;

#[inline(always)]
fn add(self, rhs: QM31) -> Self::Output {
PackedQM31::broadcast(rhs) + self
}
}

impl Mul<QM31> for PackedM31 {
type Output = PackedQM31;

#[inline(always)]
fn mul(self, rhs: QM31) -> Self::Output {
PackedQM31::broadcast(rhs) * self
}
}

impl MulAssign for PackedM31 {
#[inline(always)]
fn mul_assign(&mut self, rhs: Self) {
Expand Down
24 changes: 24 additions & 0 deletions crates/prover/src/core/backend/simd/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,30 @@ impl Sub<PackedM31> for PackedQM31 {
}
}

impl Add<QM31> for PackedQM31 {
type Output = Self;

fn add(self, rhs: QM31) -> Self::Output {
self + PackedQM31::broadcast(rhs)
}
}

impl Sub<QM31> for PackedQM31 {
type Output = Self;

fn sub(self, rhs: QM31) -> Self::Output {
self - PackedQM31::broadcast(rhs)
}
}

impl Mul<QM31> for PackedQM31 {
type Output = Self;

fn mul(self, rhs: QM31) -> Self::Output {
self * PackedQM31::broadcast(rhs)
}
}

impl SubAssign for PackedQM31 {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
Expand Down
8 changes: 8 additions & 0 deletions crates/prover/src/core/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ impl<F: Field> CirclePoint<F> {
y: self.y.into(),
}
}

pub fn mul_signed(&self, off: isize) -> CirclePoint<F> {
if off > 0 {
self.mul(off as u128)
} else {
self.conjugate().mul(-off as u128)
}
}
}

impl<F: Field> Add for CirclePoint<F> {
Expand Down
18 changes: 10 additions & 8 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,6 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
) -> Result<StarkProof, ProvingError> {
let component_traces = air.component_traces(&commitment_scheme.trees);
let lookup_values = air.lookup_values(&component_traces);
channel.mix_felts(
&lookup_values
.0
.values()
.map(|v| SecureField::from(*v))
.collect_vec(),
);

// Evaluate and commit on composition polynomial.
let random_coeff = channel.draw_felt();
Expand Down Expand Up @@ -190,8 +183,17 @@ pub fn prove<B: Backend + MerkleOps<MerkleHasher>>(
let (mut commitment_scheme, interaction_elements) =
evaluate_and_commit_on_trace(air, channel, &twiddles, trace)?;

let air = air.to_air_prover();
channel.mix_felts(
&air.lookup_values(&air.component_traces(&commitment_scheme.trees))
.0
.values()
.map(|v| SecureField::from(*v))
.collect_vec(),
);

generate_proof(
&air.to_air_prover(),
&air,
channel,
&interaction_elements,
&twiddles,
Expand Down
60 changes: 46 additions & 14 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::iter::Peekable;
use std::ops::Add;
use std::ops::{Add, Mul, Sub};

use num_traits::{One, Zero};

use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::ExtensionOf;

pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
fn assign(self, other: impl IntoIterator<Item = T>)
Expand Down Expand Up @@ -68,14 +67,23 @@ pub(crate) fn previous_bit_reversed_circle_domain_index(
domain_log_size: u32,
eval_log_size: u32,
) -> usize {
assert!(domain_log_size < eval_log_size);
let step_size = 1 << (eval_log_size - domain_log_size - 1) as usize;
offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, -1)
}

pub(crate) fn offset_bit_reversed_circle_domain_index(
i: usize,
domain_log_size: u32,
eval_log_size: u32,
offset: isize,
) -> usize {
let mut prev_index = bit_reverse_index(i, eval_log_size);
let half_size = 1 << (eval_log_size - 1);
let step_size = offset * (1 << (eval_log_size - domain_log_size - 1)) as isize;
if prev_index < half_size {
prev_index = (prev_index + half_size - step_size) % half_size;
prev_index = (prev_index as isize + step_size).rem_euclid(half_size as isize) as usize;
} else {
prev_index = ((prev_index + step_size) % half_size) + half_size;
prev_index =
((prev_index as isize - step_size).rem_euclid(half_size as isize) as usize) + half_size;
}
bit_reverse_index(prev_index, eval_log_size)
}
Expand Down Expand Up @@ -136,17 +144,13 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec<SecureF

/// Securely combines the given values using the given random alpha and z.
/// Alpha and z should be secure field elements for soundness.
pub fn shifted_secure_combination<F: ExtensionOf<BaseField>>(
values: &[F],
alpha: SecureField,
z: SecureField,
) -> SecureField
pub fn shifted_secure_combination<F: Copy, EF>(values: &[F], alpha: EF, z: EF) -> EF
where
SecureField: Add<F, Output = SecureField>,
EF: Copy + Zero + Mul<EF, Output = EF> + Add<F, Output = EF> + Sub<EF, Output = EF>,
{
let res = values
.iter()
.fold(SecureField::zero(), |acc, &value| acc * alpha + value);
.fold(EF::zero(), |acc, &value| acc * alpha + value);
res - z
}

Expand All @@ -155,12 +159,15 @@ mod tests {
use itertools::Itertools;
use num_traits::One;

use super::{
offset_bit_reversed_circle_domain_index, previous_bit_reversed_circle_domain_index,
};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::NaturalOrder;
use crate::core::utils::{bit_reverse, previous_bit_reversed_circle_domain_index};
use crate::core::utils::bit_reverse;
use crate::{m31, qm31};

#[test]
Expand Down Expand Up @@ -200,6 +207,31 @@ mod tests {
assert_eq!(powers, vec![]);
}

#[test]
fn test_offset_bit_reversed_circle_domain_index() {
let domain_log_size = 3;
let eval_log_size = 6;
let initial_index = 5;

let actual = offset_bit_reversed_circle_domain_index(
initial_index,
domain_log_size,
eval_log_size,
-2,
);
let expected_prev = previous_bit_reversed_circle_domain_index(
initial_index,
domain_log_size,
eval_log_size,
);
let expected_prev2 = previous_bit_reversed_circle_domain_index(
expected_prev,
domain_log_size,
eval_log_size,
);
assert_eq!(actual, expected_prev2);
}

#[test]
fn test_previous_bit_reversed_circle_domain_index() {
let log_size = 4;
Expand Down

0 comments on commit 126d63f

Please sign in to comment.