Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MleCollection for accumulating MLEs #800

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ impl BaseColumn {
.map(BaseColumnMutSlice)
.collect_vec()
}

pub fn into_secure_column(self) -> SecureColumn {
let length = self.len();
let data = self.data.into_iter().map(PackedSecureField::from).collect();
SecureColumn { data, length }
}
}

impl Column<BaseField> for BaseColumn {
Expand Down
18 changes: 18 additions & 0 deletions crates/prover/src/core/fields/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,24 @@ impl Mul for QM31 {
}
}

impl From<usize> for QM31 {
fn from(value: usize) -> Self {
M31::from(value).into()
}
}

impl From<u32> for QM31 {
fn from(value: u32) -> Self {
M31::from(value).into()
}
}

impl From<i32> for QM31 {
fn from(value: i32) -> Self {
M31::from(value).into()
}
}

impl TryInto<M31> for QM31 {
type Error = ();

Expand Down
8 changes: 7 additions & 1 deletion crates/prover/src/core/lookups/mle.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::Deref;
use std::ops::{Deref, DerefMut};

use educe::Educe;

Expand Down Expand Up @@ -58,6 +58,12 @@ impl<B: ColumnOps<F>, F: Field> Deref for Mle<B, F> {
}
}

impl<B: ColumnOps<F>, F: Field> DerefMut for Mle<B, F> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.evals
}
}

#[cfg(test)]
mod test {
use super::{Mle, MleOps};
Expand Down
186 changes: 186 additions & 0 deletions crates/prover/src/examples/xor/gkr_lookups/accumulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
use std::iter::zip;
use std::ops::{AddAssign, Mul};

use educe::Educe;
use num_traits::One;

use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Backend;
use crate::core::circle::M31_CIRCLE_LOG_ORDER;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::mle::Mle;
use crate::core::utils::generate_secure_powers;

pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1;

/// Max number of variables for multilinear polynomials that get compiled into a univariate
/// IOP for multilinear eval at point.
pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR;

/// Accumulates [`Mle`]s grouped by their number of variables.
pub struct MleCollection<B: Backend> {
mles_by_n_variables: Vec<Option<Vec<DynMle<B>>>>,
}

impl<B: Backend> MleCollection<B> {
/// Appends an [`Mle`] to the collection.
pub fn push(&mut self, mle: impl Into<DynMle<B>>) {
let mle = mle.into();
let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new());
mles.push(mle);
}
}

impl MleCollection<SimdBackend> {
/// Performs a random linear combination of all MLEs, grouped by their number of variables.
///
/// MLEs are returned in ascending order by number of variables.
pub fn random_linear_combine_by_n_variables(
self,
alpha: SecureField,
) -> Vec<Mle<SimdBackend, SecureField>> {
self.mles_by_n_variables
.into_iter()
.flatten()
.map(|mles| mle_random_linear_combination(mles, alpha))
.collect()
}
}

/// # Panics
///
/// Panics if `mles` is empty or all MLEs don't have the same number of variables.
fn mle_random_linear_combination(
mles: Vec<DynMle<SimdBackend>>,
alpha: SecureField,
) -> Mle<SimdBackend, SecureField> {
assert!(!mles.is_empty());
let n_variables = mles[0].n_variables();
assert!(mles.iter().all(|mle| mle.n_variables() == n_variables));
let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev();
let mut mle_and_coeff = zip(mles, alpha_powers);

// The last value can initialize the accumulator.
let (mle, coeff) = mle_and_coeff.next_back().unwrap();
assert!(coeff.is_one());
let mut acc_mle = mle.into_secure_mle();

for (mle, coeff) in mle_and_coeff {
match mle {
DynMle::Base(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()),
DynMle::Secure(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()),
}
}

acc_mle
}

/// Computes all `acc[i] += alpha * v[i]`.
pub fn combine<EF: AddAssign + Mul<F, Output = EF> + Copy, F: Copy>(
acc: &mut [EF],
v: &[F],
alpha: EF,
) {
assert_eq!(acc.len(), v.len());
zip(acc, v).for_each(|(acc, &v)| *acc += alpha * v);
}

impl<B: Backend> Default for MleCollection<B> {
fn default() -> Self {
Self {
mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1],
}
}
}

/// Dynamic dispatch for [`Mle`] types.
#[derive(Educe)]
#[educe(Debug, Clone)]
pub enum DynMle<B: Backend> {
Base(Mle<B, BaseField>),
Secure(Mle<B, SecureField>),
}

impl<B: Backend> DynMle<B> {
fn n_variables(&self) -> usize {
match self {
DynMle::Base(mle) => mle.n_variables(),
DynMle::Secure(mle) => mle.n_variables(),
}
}
}

impl<B: Backend> From<Mle<B, SecureField>> for DynMle<B> {
fn from(mle: Mle<B, SecureField>) -> Self {
DynMle::Secure(mle)
}
}

impl<B: Backend> From<Mle<B, BaseField>> for DynMle<B> {
fn from(mle: Mle<B, BaseField>) -> Self {
DynMle::Base(mle)
}
}

impl DynMle<SimdBackend> {
fn into_secure_mle(self) -> Mle<SimdBackend, SecureField> {
match self {
Self::Base(mle) => Mle::new(mle.into_evals().into_secure_column()),
Self::Secure(mle) => mle,
}
}
}

#[cfg(test)]
mod tests {
use std::iter::repeat;

use num_traits::Zero;

use crate::core::backend::simd::SimdBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::Field;
use crate::core::lookups::mle::{Mle, MleOps};
use crate::examples::xor::gkr_lookups::accumulation::MleCollection;

#[test]
fn random_linear_combine_by_n_variables() {
const SMALL_N_VARS: usize = 4;
const LARGE_N_VARS: usize = 6;
let alpha = SecureField::from(10);
let mut mle_collection = MleCollection::<SimdBackend>::default();
mle_collection.push(const_mle(SMALL_N_VARS, BaseField::from(1)));
mle_collection.push(const_mle(SMALL_N_VARS, SecureField::from(2)));
mle_collection.push(const_mle(LARGE_N_VARS, BaseField::from(3)));
mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(4)));
mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(5)));
let small_eval_point = [SecureField::zero(); SMALL_N_VARS];
let large_eval_point = [SecureField::zero(); LARGE_N_VARS];

let [small_mle, large_mle] = mle_collection
.random_linear_combine_by_n_variables(alpha)
.try_into()
.unwrap();

assert_eq!(small_mle.n_variables(), SMALL_N_VARS);
assert_eq!(large_mle.n_variables(), LARGE_N_VARS);
assert_eq!(
small_mle.eval_at_point(&small_eval_point),
SecureField::from(1) * alpha + SecureField::from(2)
);
assert_eq!(
large_mle.eval_at_point(&large_eval_point),
(SecureField::from(3) * alpha + SecureField::from(4)) * alpha + SecureField::from(5)
);
}

fn const_mle<B, F>(n_variables: usize, v: F) -> Mle<B, F>
where
B: MleOps<F>,
F: Field,
{
Mle::new(repeat(v).take(1 << n_variables).collect())
}
}
1 change: 1 addition & 0 deletions crates/prover/src/examples/xor/gkr_lookups/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod accumulation;
pub mod mle_eval;
Loading