Skip to content

Commit

Permalink
mpcs trival migrate to p3 field
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Jan 7, 2025
1 parent f0079fb commit 8d744fe
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 158 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions mpcs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ num-bigint = "0.4"
num-integer = "0.1"
plonky2.workspace = true
poseidon.workspace = true
p3-field.workspace = true
p3-goldilocks.workspace = true
rand.workspace = true
rand_chacha.workspace = true
rayon = { workspace = true, optional = true }
Expand Down
11 changes: 6 additions & 5 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ where
evals.iter().map(Evaluation::value),
&evals
.iter()
.map(|eval| E::from(1 << (num_vars - points[eval.point()].len())))
.map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len())))
.collect_vec(),
&poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(),
);
Expand Down Expand Up @@ -645,8 +645,8 @@ where
inner_product(
&poly_iter_ext(poly).collect_vec(),
build_eq_x_r_vec(point).iter(),
) * scalar
* E::from(1 << (num_vars - poly.num_vars))
) * *scalar
* E::from_canonical_u64(1 << (num_vars - poly.num_vars))
// When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube
})
.sum::<E>();
Expand Down Expand Up @@ -977,7 +977,7 @@ where
evals.iter().map(Evaluation::value),
&evals
.iter()
.map(|eval| E::from(1 << (num_vars - points[eval.point()].len())))
.map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len())))
.collect_vec(),
&poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(),
);
Expand Down Expand Up @@ -1174,14 +1174,15 @@ where

#[cfg(test)]
mod test {
use ff_ext::GoldilocksExt2;

use crate::{
basefold::Basefold,
test_util::{
gen_rand_poly_base, gen_rand_poly_ext, run_batch_commit_open_verify,
run_commit_open_verify, run_simple_batch_commit_open_verify,
},
};
use goldilocks::GoldilocksExt2;

use super::{BasefoldRSParams, structure::BasefoldBasecodeParams};

Expand Down
4 changes: 3 additions & 1 deletion mpcs/src/basefold/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ pub(crate) mod test_util {
pub fn test_codeword_folding<E: ExtensionField, Code: EncodingScheme<E>>() {
let num_vars = 12;

let poly: Vec<E> = (0..(1 << num_vars)).map(|i| E::from(i)).collect();
let poly: Vec<E> = (0..(1 << num_vars))
.map(|i| E::from_canonical_u64(i))
.collect();
let mut poly = FieldType::Ext(poly);

let pp: Code::PublicParameters = Code::setup(num_vars);
Expand Down
14 changes: 7 additions & 7 deletions mpcs/src/basefold/encoding/basecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use crate::{
};
use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ark_std::{end_timer, start_timer};
use ff::{BatchInvert, Field, PrimeField};
use ff_ext::ExtensionField;
use generic_array::GenericArray;
use multilinear_extensions::mle::FieldType;
use p3_field::{Field, FieldAlgebra, batch_multiplicative_inverse};
use rand::SeedableRng;
use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut};

Expand Down Expand Up @@ -216,7 +216,7 @@ where
let x0: E::BaseField = query_root_table_from_rng_aes::<E>(level, index, &mut cipher);
let x1 = -x0;

let w = (x1 - x0).invert().unwrap();
let w = (x1 - x0).try_inverse().unwrap();

(E::from(x0), E::from(x1), E::from(w))
}
Expand Down Expand Up @@ -351,13 +351,13 @@ pub fn get_table_aes<E: ExtensionField, Rng: RngCore + Clone>(
assert_eq!(flat_table.len(), 1 << lg_n);

// Multiply -2 to every element to get the weights. Now weights = { -2x }
let mut weights: Vec<E::BaseField> = flat_table
let weights: Vec<E::BaseField> = flat_table
.par_iter()
.map(|el| E::BaseField::ZERO - *el - *el)
.collect();

// Then invert all the elements. Now weights = { -1/2x }
BatchInvert::batch_invert(&mut weights);
let weights = batch_multiplicative_inverse(&weights);

// Zip x and -1/2x together. The result is the list { (x, -1/2x) }
// What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which
Expand Down Expand Up @@ -399,13 +399,13 @@ pub fn query_root_table_from_rng_aes<E: ExtensionField>(
}

let pos = ((level_offset + (reverse_bits(index, level) as u128))
* ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128))
* ((E::BaseField::bits() as usize).next_power_of_two() as u128))
.checked_div(8)
.unwrap();

cipher.seek(pos);

let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8;
let bytes = (E::BaseField::bits() as usize).next_power_of_two() / 8;
let mut dest: Vec<u8> = vec![0u8; bytes];
cipher.apply_keystream(&mut dest);

Expand All @@ -417,7 +417,7 @@ mod tests {
use crate::basefold::encoding::test_util::test_codeword_folding;

use super::*;
use goldilocks::GoldilocksExt2;
use ff_ext::GoldilocksExt2;
use multilinear_extensions::mle::DenseMultilinearExtension;

#[test]
Expand Down
86 changes: 38 additions & 48 deletions mpcs/src/basefold/encoding/rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::{
vec_mut,
};
use ark_std::{end_timer, start_timer};
use ff::{Field, PrimeField};
use ff_ext::ExtensionField;
use multilinear_extensions::mle::FieldType;
use p3_field::{Field, FieldAlgebra, PrimeField};

use serde::{Deserialize, Serialize, de::DeserializeOwned};

Expand Down Expand Up @@ -71,8 +71,7 @@ fn ifft<E: ExtensionField>(
let n = poly.len();
let lg_n = log2_strict(n);
let n_inv = (E::BaseField::ONE + E::BaseField::ONE)
.invert()
.unwrap()
.inverse()
.pow([lg_n as u64]);

fft(poly, zero_factor, root_table);
Expand Down Expand Up @@ -310,13 +309,13 @@ where
}
let mut gamma_powers = Vec::with_capacity(max_message_size_log);
let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log);
gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR);
gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap());
gamma_powers.push(E::BaseField::GENERATOR);
gamma_powers_inv.push(E::BaseField::GENERATOR.inverse());
for i in 1..max_message_size_log + Spec::get_rate_log() {
gamma_powers.push(gamma_powers[i - 1].square());
gamma_powers_inv.push(gamma_powers_inv[i - 1].square());
}
let inv_of_two = E::BaseField::from(2).invert().unwrap();
let inv_of_two = E::BaseField::from_canonical_u64(2).inverse();
gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two);
pp.fft_root_table
.truncate(max_message_size_log + Spec::get_rate_log());
Expand Down Expand Up @@ -493,7 +492,7 @@ impl<Spec: RSCodeSpec> RSCode<Spec> {
let k = 1 << (full_message_size_log - lg_m);
coset_fft(
&mut ret,
E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]),
E::BaseField::GENERATOR.pow([k]),
Spec::get_rate_log(),
fft_root_table,
);
Expand All @@ -514,7 +513,7 @@ impl<Spec: RSCodeSpec> RSCode<Spec> {
let x0 = E::BaseField::ROOT_OF_UNITY
.pow([1 << (E::BaseField::S - (level as u32 + 1))])
.pow([index as u64])
* E::BaseField::MULTIPLICATIVE_GENERATOR
* E::BaseField::GENERATOR
.pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]);
let x1 = -x0;
let w = (x1 - x0).invert().unwrap();
Expand Down Expand Up @@ -546,19 +545,24 @@ fn naive_fft<E: ExtensionField>(poly: &[E], rate: usize, shift: E::BaseField) ->

#[cfg(test)]
mod tests {
use ff_ext::GoldilocksExt2;
use p3_goldilocks::Goldilocks;

use crate::{
basefold::encoding::test_util::test_codeword_folding,
util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type},
};
use ff_ext::FromUniformBytes;

use super::*;
use goldilocks::{Goldilocks, GoldilocksExt2};

#[test]
fn test_naive_fft() {
let num_vars = 5;

let poly: Vec<GoldilocksExt2> = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect();
let poly: Vec<GoldilocksExt2> = (0..(1 << num_vars))
.map(GoldilocksExt2::from_canonical_u64)
.collect();
let mut poly2 = FieldType::Ext(poly.clone());

let naive = naive_fft::<GoldilocksExt2>(&poly, 1, Goldilocks::ONE);
Expand All @@ -583,15 +587,10 @@ mod tests {
.collect();
let mut poly2 = FieldType::Ext(poly.clone());

let naive = naive_fft::<GoldilocksExt2>(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR);
let naive = naive_fft::<GoldilocksExt2>(&poly, 1, Goldilocks::GENERATOR);

let root_table = fft_root_table(num_vars);
coset_fft::<GoldilocksExt2>(
&mut poly2,
Goldilocks::MULTIPLICATIVE_GENERATOR,
0,
&root_table,
);
coset_fft::<GoldilocksExt2>(&mut poly2, Goldilocks::GENERATOR, 0, &root_table);

let poly2 = match poly2 {
FieldType::Ext(coeffs) => coeffs,
Expand All @@ -613,19 +612,10 @@ mod tests {
poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice());
let mut poly2 = FieldType::Ext(poly2.clone());

let naive = naive_fft::<GoldilocksExt2>(
&poly,
1 << rate_bits,
Goldilocks::MULTIPLICATIVE_GENERATOR,
);
let naive = naive_fft::<GoldilocksExt2>(&poly, 1 << rate_bits, Goldilocks::GENERATOR);

let root_table = fft_root_table(num_vars + rate_bits);
coset_fft::<GoldilocksExt2>(
&mut poly2,
Goldilocks::MULTIPLICATIVE_GENERATOR,
rate_bits,
&root_table,
);
coset_fft::<GoldilocksExt2>(&mut poly2, Goldilocks::GENERATOR, rate_bits, &root_table);

let poly2 = match poly2 {
FieldType::Ext(coeffs) => coeffs,
Expand All @@ -638,7 +628,9 @@ mod tests {
fn test_ifft() {
let num_vars = 5;

let poly: Vec<GoldilocksExt2> = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect();
let poly: Vec<GoldilocksExt2> = (0..(1 << num_vars))
.map(GoldilocksExt2::from_canonical_u64)
.collect();
let mut poly = FieldType::Ext(poly);
let original = poly.clone();

Expand Down Expand Up @@ -686,14 +678,14 @@ mod tests {
pub fn test_colinearity() {
let num_vars = 10;

let poly: Vec<E> = (0..(1 << num_vars)).map(E::from).collect();
let poly: Vec<E> = (0..(1 << num_vars)).map(E::from_canonical_u64).collect();
let poly = FieldType::Ext(poly);

let pp = <Code as EncodingScheme<E>>::setup(num_vars);
let (pp, _) = Code::trim(pp, num_vars).unwrap();
let mut codeword = Code::encode(&pp, &poly);
reverse_index_bits_in_place_field_type(&mut codeword);
let challenge = E::from(2);
let challenge = E::from_canonical_u64(2);
let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge);
let codeword = match codeword {
FieldType::Ext(coeffs) => coeffs,
Expand All @@ -712,8 +704,8 @@ mod tests {
// which is equivalent to
// (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a)
assert_eq!(
(x0 - challenge) * (b[1] - a),
(x1 - challenge) * (b[0] - a),
(x0 - challenge) * (b[1] - *a),
(x1 - challenge) * (b[0] - *a),
"failed for i = {}",
i
);
Expand All @@ -724,7 +716,7 @@ mod tests {
pub fn test_low_degree() {
let num_vars = 10;

let poly: Vec<E> = (0..(1 << num_vars)).map(E::from).collect();
let poly: Vec<E> = (0..(1 << num_vars)).map(E::from_canonical_u64).collect();
let poly = FieldType::Ext(poly);

let pp = <Code as EncodingScheme<E>>::setup(num_vars);
Expand Down Expand Up @@ -789,7 +781,7 @@ mod tests {
"check low degree of (left-right)*omega^(-i)",
);

let challenge = E::from(2);
let challenge = E::from_canonical_u64(2);
let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge);
let c_fold = folded_codeword[0];
let c_fold1 = folded_codeword[folded_codeword.len() >> 1];
Expand All @@ -800,7 +792,7 @@ mod tests {

// The top level folding coefficient should have shift factor gamma
let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0);
assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR));
assert_eq!(folding_coeffs.0, E::from(F::GENERATOR));
assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO);
assert_eq!(
(folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2,
Expand All @@ -815,17 +807,16 @@ mod tests {
// So the folded value should be equal to
// (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2
assert_eq!(
c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2),
challenge * (c0 - c_mid) + (c0 + c_mid) * F::MULTIPLICATIVE_GENERATOR
c_fold * F::GENERATOR * F::from_canonical_u64(2),
challenge * (c0 - c_mid) + (c0 + c_mid) * F::GENERATOR
);
assert_eq!(
c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2),
challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR
c_fold * F::GENERATOR * F::from_canonical_u64(2),
challenge * left_right_diff[0] + left_right_sum[0] * F::GENERATOR
);
assert_eq!(
c_fold * F::from(2),
challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap()
+ left_right_sum[0]
c_fold * F::from_canonical_u64(2),
challenge * left_right_diff[0] * F::GENERATOR.inverse() + left_right_sum[0]
);

let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1);
Expand All @@ -835,8 +826,7 @@ mod tests {
assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE);
assert_eq!(
folding_coeffs.0,
E::from(F::MULTIPLICATIVE_GENERATOR)
* E::from(root_of_unity).pow([(codeword.len() >> 2) as u64])
E::from(F::GENERATOR) * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64])
);
assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO);
assert_eq!(
Expand All @@ -849,22 +839,22 @@ mod tests {
// The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha.
// In another word, the folded codeword multipled by 2 is the linear
// combination by coeffs: 1 and gamma^{-1} * alpha
let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap();
let gamma_inv = F::GENERATOR.inverse();
let b = challenge * gamma_inv;
let folded_codeword_vec = match &folded_codeword {
FieldType::Ext(coeffs) => coeffs.clone(),
_ => panic!("Wrong field type"),
};
assert_eq!(
c_fold * F::from(2),
c_fold * F::from_canonical_u64(2),
left_right_diff[0] * b + left_right_sum[0]
);
for (i, (c, (diff, sum))) in folded_codeword_vec
.iter()
.zip(left_right_diff.iter().zip(left_right_sum.iter()))
.enumerate()
{
assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i);
assert_eq!(*c + *c, *sum + b * *diff, "failed for i = {}", i);
}

check_low_degree(&folded_codeword, "low degree check for folded");
Expand Down
Loading

0 comments on commit 8d744fe

Please sign in to comment.