From ad27c0bec4959df4623ad740b470e613414ba59c Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 13:15:09 +0800 Subject: [PATCH] mpcs trival migrate to p3 field --- Cargo.lock | 2 + mpcs/Cargo.toml | 2 + mpcs/src/basefold.rs | 11 ++-- mpcs/src/basefold/encoding.rs | 4 +- mpcs/src/basefold/encoding/basecode.rs | 14 ++--- mpcs/src/basefold/encoding/rs.rs | 86 ++++++++++++-------------- mpcs/src/sum_check.rs | 17 ++--- mpcs/src/sum_check/classic.rs | 63 ++++++++++--------- mpcs/src/sum_check/classic/coeff.rs | 6 +- mpcs/src/util.rs | 33 +++++----- mpcs/src/util/arithmetic.rs | 55 ++++------------ sumcheck/src/test.rs | 2 +- 12 files changed, 137 insertions(+), 158 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1a9dbc8d..dabade28e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1141,6 +1141,8 @@ dependencies = [ "multilinear_extensions", "num-bigint", "num-integer", + "p3-field", + "p3-goldilocks", "plonky2", "poseidon", "rand", diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index f977328cc..bb64b3cf8 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -25,6 +25,8 @@ num-bigint = "0.4" num-integer = "0.1" plonky2.workspace = true poseidon.workspace = true +p3-field.workspace = true +p3-goldilocks.workspace = true rand.workspace = true rand_chacha.workspace = true rayon = { workspace = true, optional = true } diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 6204ed038..34ae908d0 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -594,7 +594,7 @@ where evals.iter().map(Evaluation::value), &evals .iter() - .map(|eval| E::from(1 << (num_vars - points[eval.point()].len()))) + .map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len()))) .collect_vec(), &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); @@ -645,8 +645,8 @@ where inner_product( &poly_iter_ext(poly).collect_vec(), build_eq_x_r_vec(point).iter(), - ) * scalar - * E::from(1 << (num_vars - poly.num_vars)) + ) * *scalar + * E::from_canonical_u64(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube }) .sum::(); @@ -977,7 +977,7 @@ where evals.iter().map(Evaluation::value), &evals .iter() - .map(|eval| E::from(1 << (num_vars - points[eval.point()].len()))) + .map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len()))) .collect_vec(), &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); @@ -1174,6 +1174,8 @@ where #[cfg(test)] mod test { + use ff_ext::GoldilocksExt2; + use crate::{ basefold::Basefold, test_util::{ @@ -1181,7 +1183,6 @@ mod test { run_commit_open_verify, run_simple_batch_commit_open_verify, }, }; - use goldilocks::GoldilocksExt2; use super::{BasefoldRSParams, structure::BasefoldBasecodeParams}; diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 410d35970..6c3a03d2f 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -173,7 +173,9 @@ pub(crate) mod test_util { pub fn test_codeword_folding>() { let num_vars = 12; - let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(|i| E::from_canonical_u64(i)) + .collect(); let mut poly = FieldType::Ext(poly); let pp: Code::PublicParameters = Code::setup(num_vars); diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 9fbee84f1..04ea5892d 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -10,10 +10,10 @@ use crate::{ }; use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use ark_std::{end_timer, start_timer}; -use ff::{BatchInvert, Field, PrimeField}; use ff_ext::ExtensionField; use generic_array::GenericArray; use multilinear_extensions::mle::FieldType; +use p3_field::{Field, FieldAlgebra, batch_multiplicative_inverse}; use rand::SeedableRng; use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; @@ -216,7 +216,7 @@ where let x0: E::BaseField = query_root_table_from_rng_aes::(level, index, &mut cipher); let x1 = -x0; - let w = (x1 - x0).invert().unwrap(); + let w = (x1 - x0).try_inverse().unwrap(); (E::from(x0), E::from(x1), E::from(w)) } @@ -351,13 +351,13 @@ pub fn get_table_aes( assert_eq!(flat_table.len(), 1 << lg_n); // Multiply -2 to every element to get the weights. Now weights = { -2x } - let mut weights: Vec = flat_table + let weights: Vec = flat_table .par_iter() .map(|el| E::BaseField::ZERO - *el - *el) .collect(); // Then invert all the elements. Now weights = { -1/2x } - BatchInvert::batch_invert(&mut weights); + let weights = batch_multiplicative_inverse(&weights); // Zip x and -1/2x together. The result is the list { (x, -1/2x) } // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which @@ -399,13 +399,13 @@ pub fn query_root_table_from_rng_aes( } let pos = ((level_offset + (reverse_bits(index, level) as u128)) - * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) + * ((E::BaseField::bits() as usize).next_power_of_two() as u128)) .checked_div(8) .unwrap(); cipher.seek(pos); - let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; + let bytes = (E::BaseField::bits() as usize).next_power_of_two() / 8; let mut dest: Vec = vec![0u8; bytes]; cipher.apply_keystream(&mut dest); @@ -417,7 +417,7 @@ mod tests { use crate::basefold::encoding::test_util::test_codeword_folding; use super::*; - use goldilocks::GoldilocksExt2; + use ff_ext::GoldilocksExt2; use multilinear_extensions::mle::DenseMultilinearExtension; #[test] diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 2bcac0826..9ae955d28 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -7,9 +7,9 @@ use crate::{ vec_mut, }; use ark_std::{end_timer, start_timer}; -use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::{Field, FieldAlgebra, PrimeField}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -71,8 +71,7 @@ fn ifft( let n = poly.len(); let lg_n = log2_strict(n); let n_inv = (E::BaseField::ONE + E::BaseField::ONE) - .invert() - .unwrap() + .inverse() .pow([lg_n as u64]); fft(poly, zero_factor, root_table); @@ -310,13 +309,13 @@ where } let mut gamma_powers = Vec::with_capacity(max_message_size_log); let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); - gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); - gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap()); + gamma_powers.push(E::BaseField::GENERATOR); + gamma_powers_inv.push(E::BaseField::GENERATOR.inverse()); for i in 1..max_message_size_log + Spec::get_rate_log() { gamma_powers.push(gamma_powers[i - 1].square()); gamma_powers_inv.push(gamma_powers_inv[i - 1].square()); } - let inv_of_two = E::BaseField::from(2).invert().unwrap(); + let inv_of_two = E::BaseField::from_canonical_u64(2).inverse(); gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); pp.fft_root_table .truncate(max_message_size_log + Spec::get_rate_log()); @@ -493,7 +492,7 @@ impl RSCode { let k = 1 << (full_message_size_log - lg_m); coset_fft( &mut ret, - E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), + E::BaseField::GENERATOR.pow([k]), Spec::get_rate_log(), fft_root_table, ); @@ -514,7 +513,7 @@ impl RSCode { let x0 = E::BaseField::ROOT_OF_UNITY .pow([1 << (E::BaseField::S - (level as u32 + 1))]) .pow([index as u64]) - * E::BaseField::MULTIPLICATIVE_GENERATOR + * E::BaseField::GENERATOR .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); let x1 = -x0; let w = (x1 - x0).invert().unwrap(); @@ -546,19 +545,24 @@ fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> #[cfg(test)] mod tests { + use ff_ext::GoldilocksExt2; + use p3_goldilocks::Goldilocks; + use crate::{ basefold::encoding::test_util::test_codeword_folding, util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type}, }; + use ff_ext::FromUniformBytes; use super::*; - use goldilocks::{Goldilocks, GoldilocksExt2}; #[test] fn test_naive_fft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(GoldilocksExt2::from_canonical_u64) + .collect(); let mut poly2 = FieldType::Ext(poly.clone()); let naive = naive_fft::(&poly, 1, Goldilocks::ONE); @@ -583,15 +587,10 @@ mod tests { .collect(); let mut poly2 = FieldType::Ext(poly.clone()); - let naive = naive_fft::(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR); + let naive = naive_fft::(&poly, 1, Goldilocks::GENERATOR); let root_table = fft_root_table(num_vars); - coset_fft::( - &mut poly2, - Goldilocks::MULTIPLICATIVE_GENERATOR, - 0, - &root_table, - ); + coset_fft::(&mut poly2, Goldilocks::GENERATOR, 0, &root_table); let poly2 = match poly2 { FieldType::Ext(coeffs) => coeffs, @@ -613,19 +612,10 @@ mod tests { poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice()); let mut poly2 = FieldType::Ext(poly2.clone()); - let naive = naive_fft::( - &poly, - 1 << rate_bits, - Goldilocks::MULTIPLICATIVE_GENERATOR, - ); + let naive = naive_fft::(&poly, 1 << rate_bits, Goldilocks::GENERATOR); let root_table = fft_root_table(num_vars + rate_bits); - coset_fft::( - &mut poly2, - Goldilocks::MULTIPLICATIVE_GENERATOR, - rate_bits, - &root_table, - ); + coset_fft::(&mut poly2, Goldilocks::GENERATOR, rate_bits, &root_table); let poly2 = match poly2 { FieldType::Ext(coeffs) => coeffs, @@ -638,7 +628,9 @@ mod tests { fn test_ifft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(GoldilocksExt2::from_canonical_u64) + .collect(); let mut poly = FieldType::Ext(poly); let original = poly.clone(); @@ -686,14 +678,14 @@ mod tests { pub fn test_colinearity() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from_canonical_u64).collect(); let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); - let challenge = E::from(2); + let challenge = E::from_canonical_u64(2); let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); let codeword = match codeword { FieldType::Ext(coeffs) => coeffs, @@ -712,8 +704,8 @@ mod tests { // which is equivalent to // (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a) assert_eq!( - (x0 - challenge) * (b[1] - a), - (x1 - challenge) * (b[0] - a), + (x0 - challenge) * (b[1] - *a), + (x1 - challenge) * (b[0] - *a), "failed for i = {}", i ); @@ -724,7 +716,7 @@ mod tests { pub fn test_low_degree() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from_canonical_u64).collect(); let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); @@ -789,7 +781,7 @@ mod tests { "check low degree of (left-right)*omega^(-i)", ); - let challenge = E::from(2); + let challenge = E::from_canonical_u64(2); let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); let c_fold = folded_codeword[0]; let c_fold1 = folded_codeword[folded_codeword.len() >> 1]; @@ -800,7 +792,7 @@ mod tests { // The top level folding coefficient should have shift factor gamma let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0); - assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR)); + assert_eq!(folding_coeffs.0, E::from(F::GENERATOR)); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, @@ -815,17 +807,16 @@ mod tests { // So the folded value should be equal to // (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2 assert_eq!( - c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), - challenge * (c0 - c_mid) + (c0 + c_mid) * F::MULTIPLICATIVE_GENERATOR + c_fold * F::GENERATOR * F::from_canonical_u64(2), + challenge * (c0 - c_mid) + (c0 + c_mid) * F::GENERATOR ); assert_eq!( - c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), - challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR + c_fold * F::GENERATOR * F::from_canonical_u64(2), + challenge * left_right_diff[0] + left_right_sum[0] * F::GENERATOR ); assert_eq!( - c_fold * F::from(2), - challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap() - + left_right_sum[0] + c_fold * F::from_canonical_u64(2), + challenge * left_right_diff[0] * F::GENERATOR.inverse() + left_right_sum[0] ); let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); @@ -835,8 +826,7 @@ mod tests { assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); assert_eq!( folding_coeffs.0, - E::from(F::MULTIPLICATIVE_GENERATOR) - * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) + E::from(F::GENERATOR) * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) ); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( @@ -849,14 +839,14 @@ mod tests { // The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha. // In another word, the folded codeword multipled by 2 is the linear // combination by coeffs: 1 and gamma^{-1} * alpha - let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap(); + let gamma_inv = F::GENERATOR.inverse(); let b = challenge * gamma_inv; let folded_codeword_vec = match &folded_codeword { FieldType::Ext(coeffs) => coeffs.clone(), _ => panic!("Wrong field type"), }; assert_eq!( - c_fold * F::from(2), + c_fold * F::from_canonical_u64(2), left_right_diff[0] * b + left_right_sum[0] ); for (i, (c, (diff, sum))) in folded_codeword_vec @@ -864,7 +854,7 @@ mod tests { .zip(left_right_diff.iter().zip(left_right_sum.iter())) .enumerate() { - assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i); + assert_eq!(*c + *c, *sum + b * *diff, "failed for i = {}", i); } check_low_degree(&folded_codeword, "low degree check for folded"); diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index f2fcf0e47..7406025ca 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -9,10 +9,10 @@ use crate::{ use std::{collections::HashMap, fmt::Debug}; use classic::{ClassicSumCheckRoundMessage, SumcheckProof}; -use ff::PrimeField; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; +use p3_field::Field; use serde::{Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -113,27 +113,30 @@ pub fn evaluate( ) } -pub fn lagrange_eval(x: &[F], b: usize) -> F { +pub fn lagrange_eval(x: &[F], b: usize) -> F { assert!(!x.is_empty()); product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } + if b.nth_bit(idx) { *x_i } else { F::ONE - *x_i } }, )) } -pub fn eq_xy_eval(x: &[F], y: &[F]) -> F { +pub fn eq_xy_eval(x: &[F], y: &[F]) -> F { assert!(!x.is_empty()); assert_eq!(x.len(), y.len()); product( x.iter() .zip(y) - .map(|(x_i, y_i)| (*x_i * y_i).double() + F::ONE - x_i - y_i), + .map(|(x_i, y_i)| (*x_i * *y_i).double() + F::ONE - *x_i - *y_i), ) } -fn identity_eval(x: &[F]) -> F { - inner_product(x, &powers(F::from(2)).take(x.len()).collect_vec()) +fn identity_eval(x: &[F]) -> F { + inner_product( + x, + &powers(F::from_canonical_u64(2)).take(x.len()).collect_vec(), + ) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index f99d832df..ea7bfc92f 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -9,7 +9,6 @@ use crate::{ }, }; use ark_std::{end_timer, start_timer}; -use ff::Field; use ff_ext::ExtensionField; use itertools::Itertools; use num_integer::Integer; @@ -24,6 +23,7 @@ use multilinear_extensions::{ pub(crate) use coeff::Coefficients; pub use coeff::CoefficientsProver; +use p3_field::FieldAlgebra; #[derive(Debug)] pub struct ProverState<'a, E: ExtensionField> { @@ -99,12 +99,12 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { fn next_round(&mut self, sum: E, challenge: &E) { self.sum = sum; - self.identity += E::from(1 << self.round) * challenge; + self.identity += E::from_canonical_u64(1 << self.round) * *challenge; self.lagranges.values_mut().for_each(|(b, value)| { if b.is_even() { - *value *= &(E::ONE - challenge); + *value *= E::ONE - *challenge; } else { - *value *= challenge; + *value *= *challenge; } *b >>= 1; }); @@ -324,51 +324,58 @@ mod tests { use transcript::BasicTranscript; use super::*; - use goldilocks::{Goldilocks as Fr, GoldilocksExt2 as E}; + use ff_ext::GoldilocksExt2 as E; + use p3_goldilocks::{Goldilocks as Fr, MdsMatrixGoldilocks}; #[test] fn test_sum_check_protocol() { let polys = [ DenseMultilinearExtension::::from_evaluations_vec(2, vec![ - Fr::from(1), - Fr::from(2), - Fr::from(3), - Fr::from(4), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(2), + Fr::from_canonical_u64(3), + Fr::from_canonical_u64(4), ]), DenseMultilinearExtension::from_evaluations_vec(2, vec![ - Fr::from(0), - Fr::from(1), - Fr::from(1), - Fr::from(0), + Fr::from_canonical_u64(0), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(0), ]), - DenseMultilinearExtension::from_evaluations_vec(1, vec![Fr::from(0), Fr::from(1)]), + DenseMultilinearExtension::from_evaluations_vec(1, vec![ + Fr::from_canonical_u64(0), + Fr::from_canonical_u64(1), + ]), + ]; + let points = vec![ + vec![E::from_canonical_u64(1), E::from_canonical_u64(2)], + vec![E::from_canonical_u64(1)], ]; - let points = vec![vec![E::from(1), E::from(2)], vec![E::from(1)]]; let expression = Expression::::eq_xy(0) * Expression::Polynomial(Query::new(0, Rotation::cur())) - * E::from(Fr::from(2)) + * E::from(Fr::from_canonical_u64(2)) + Expression::::eq_xy(0) * Expression::Polynomial(Query::new(1, Rotation::cur())) - * E::from(Fr::from(3)) + * E::from(Fr::from_canonical_u64(3)) + Expression::::eq_xy(1) * Expression::Polynomial(Query::new(2, Rotation::cur())) - * E::from(Fr::from(4)); + * E::from(Fr::from_canonical_u64(4)); let virtual_poly = VirtualPolynomial::::new(&expression, polys.iter(), &[], points.as_slice()); let sum = inner_product( &poly_iter_ext(&polys[0]).collect_vec(), &build_eq_x_r_vec(&points[0]), - ) * Fr::from(2) + ) * Fr::from_canonical_u64(2) + inner_product( &poly_iter_ext(&polys[1]).collect_vec(), &build_eq_x_r_vec(&points[0]), - ) * Fr::from(3) + ) * Fr::from_canonical_u64(3) + inner_product( &poly_iter_ext(&polys[2]).collect_vec(), &build_eq_x_r_vec(&points[1]), - ) * Fr::from(4) - * Fr::from(2); // The third polynomial is summed twice because the hypercube is larger - let mut transcript = BasicTranscript::::new(b"sumcheck"); + ) * Fr::from_canonical_u64(4) + * Fr::from_canonical_u64(2); // The third polynomial is summed twice because the hypercube is larger + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (challenges, evals, proof) = > as SumCheck>::prove( &(), @@ -383,7 +390,7 @@ mod tests { assert_eq!(polys[1].evaluate(&challenges), evals[1]); assert_eq!(polys[2].evaluate(&challenges[..1]), evals[2]); - let mut transcript = BasicTranscript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (new_sum, verifier_challenges) = > as SumCheck< E, @@ -395,12 +402,12 @@ mod tests { assert_eq!(verifier_challenges, challenges); assert_eq!( new_sum, - evals[0] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from(2) - + evals[1] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from(3) - + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from(4) + evals[0] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from_canonical_u64(2) + + evals[1] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from_canonical_u64(3) + + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from_canonical_u64(4) ); - let mut transcript = BasicTranscript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); > as SumCheck>::verify( &(), diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 10d5c1c20..b49d32e73 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -267,10 +267,10 @@ impl CoefficientsProver { .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); - coeffs[0] += &coeff_0; - coeffs[2] += &coeff_2; + coeffs[0] += coeff_0; + coeffs[2] += coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); + coeffs[1] += lhs_1 * rhs_1 - coeff_0 - coeff_2; } }); }; diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 7688b53ec..49be09983 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -3,14 +3,13 @@ pub mod expression; pub mod hash; pub mod parallel; pub mod plonky2_util; -use ff::{Field, PrimeField}; -use ff_ext::ExtensionField; -use goldilocks::SmallField; +use ff_ext::{ExtensionField, SmallField}; use itertools::{Either, Itertools, izip}; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; pub mod merkle_tree; use crate::{Error, util::parallel::parallelize}; +use p3_field::{FieldAlgebra, PrimeField}; pub use plonky2_util::log2_strict; pub fn ext_to_usize(x: &E) -> usize { @@ -23,7 +22,7 @@ pub fn base_to_usize(x: &E::BaseField) -> usize { } pub fn u32_to_field(x: u32) -> E::BaseField { - E::BaseField::from(x as u64) + E::BaseField::from_canonical_u32(x) } pub trait BitIndex { @@ -38,7 +37,7 @@ impl BitIndex for usize { /// How many bytes are required to store n field elements? pub fn num_of_bytes(n: usize) -> usize { - (F::NUM_BITS as usize).next_power_of_two() * n / 8 + (F::bits() as usize).next_power_of_two() * n / 8 } macro_rules! impl_index { @@ -118,8 +117,8 @@ pub fn field_type_index_mul_base( scalar: &E::BaseField, ) { match poly { - FieldType::Ext(coeffs) => coeffs[index] *= scalar, - FieldType::Base(coeffs) => coeffs[index] *= scalar, + FieldType::Ext(coeffs) => coeffs[index] *= *scalar, + FieldType::Base(coeffs) => coeffs[index] *= *scalar, _ => unreachable!(), } } @@ -194,13 +193,13 @@ pub fn multiply_poly(poly: &mut DenseMultilinearExtension, match &mut poly.evaluations { FieldType::Ext(coeffs) => { for coeff in coeffs.iter_mut() { - *coeff *= scalar; + *coeff *= *scalar; } } FieldType::Base(coeffs) => { *poly = DenseMultilinearExtension::::from_evaluations_ext_vec( poly.num_vars, - coeffs.iter().map(|x| E::from(*x) * scalar).collect(), + coeffs.iter().map(|x| E::from(*x) * *scalar).collect(), ); } _ => unreachable!(), @@ -320,11 +319,12 @@ pub fn ext_try_into_base(x: &E) -> Result(mut rng: impl RngCore) -> [F; N] { + pub fn rand_array(mut rng: impl RngCore) -> [F; N] { array::from_fn(|_| F::random(&mut rng)) } - pub fn rand_vec(n: usize, mut rng: impl RngCore) -> Vec { + pub fn rand_vec(n: usize, mut rng: impl RngCore) -> Vec { iter::repeat_with(|| F::random(&mut rng)).take(n).collect() } #[test] pub fn test_field_transform() { - assert_eq!(F::from(2) * F::from(3), F::from(6)); + assert_eq!( + F::from_canonical_u64(2) * F::from_canonical_u64(3), + F::from_canonical_u64(6) + ); assert_eq!(base_to_usize::(&u32_to_field::(1u32)), 1); assert_eq!(base_to_usize::(&u32_to_field::(10u32)), 10); } diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index 609f65455..657d76646 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -1,8 +1,7 @@ -use ff::{BatchInvert, Field, PrimeField}; - use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; use num_integer::Integer; +use p3_field::{Field, PrimeField}; use std::{borrow::Borrow, iter}; mod bh; @@ -13,6 +12,7 @@ pub use hypercube::{ interpolate_field_type_over_boolean_hypercube, interpolate_over_boolean_hypercube, }; use num_bigint::BigUint; +use p3_field::FieldAlgebra; use itertools::Itertools; @@ -29,7 +29,7 @@ pub fn horner(coeffs: &[F], x: &F) -> F { let coeff_vec: Vec<&F> = coeffs.iter().rev().collect(); let mut acc = F::ZERO; for c in coeff_vec { - acc = acc * x + c; + acc = acc * *x + *c; } acc // 2 @@ -40,7 +40,7 @@ pub fn horner(coeffs: &[F], x: &F) -> F { pub fn horner_base(coeffs: &[E::BaseField], x: &E) -> E { let mut acc = E::ZERO; for c in coeffs.iter().rev() { - acc = acc * x + E::from(*c); + acc = acc * *x + E::from(*c); } acc // 2 @@ -52,11 +52,11 @@ pub fn steps(start: F) -> impl Iterator { } pub fn steps_by(start: F, step: F) -> impl Iterator { - iter::successors(Some(start), move |state| Some(step + state)) + iter::successors(Some(start), move |state| Some(step + *state)) } pub fn powers(scalar: F) -> impl Iterator { - iter::successors(Some(F::ONE), move |power| Some(scalar * power)) + iter::successors(Some(F::ONE), move |power| Some(scalar * *power)) } pub fn squares(scalar: F) -> impl Iterator { @@ -66,13 +66,13 @@ pub fn squares(scalar: F) -> impl Iterator { pub fn product(values: impl IntoIterator>) -> F { values .into_iter() - .fold(F::ONE, |acc, value| acc * value.borrow()) + .fold(F::ONE, |acc, value| acc * *value.borrow()) } pub fn sum(values: impl IntoIterator>) -> F { values .into_iter() - .fold(F::ZERO, |acc, value| acc + value.borrow()) + .fold(F::ZERO, |acc, value| acc + *value.borrow()) } pub fn inner_product<'a, 'b, F: Field>( @@ -81,7 +81,7 @@ pub fn inner_product<'a, 'b, F: Field>( ) -> F { lhs.into_iter() .zip_eq(rhs) - .map(|(lhs, rhs)| *lhs * rhs) + .map(|(lhs, rhs)| *lhs * *rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() } @@ -94,42 +94,11 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( a.into_iter() .zip_eq(b) .zip_eq(c) - .map(|((a, b), c)| *a * b * c) + .map(|((a, b), c)| *a * *b * *c) .reduce(|acc, product| acc + product) .unwrap_or_default() } -pub fn barycentric_weights(points: &[F]) -> Vec { - let mut weights = points - .iter() - .enumerate() - .map(|(j, point_j)| { - points - .iter() - .enumerate() - .filter(|&(i, _point_i)| (i != j)) - .map(|(_i, point_i)| *point_j - point_i) - .reduce(|acc, value| acc * value) - .unwrap_or(F::ONE) - }) - .collect_vec(); - weights.iter_mut().batch_invert(); - weights -} - -pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F], x: &F) -> F { - let (coeffs, sum_inv) = { - let mut coeffs = points.iter().map(|point| *x - point).collect_vec(); - coeffs.iter_mut().batch_invert(); - coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { - *coeff *= weight; - }); - let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); - (coeffs, sum_inv.invert().unwrap()) - }; - inner_product(&coeffs, evals) * sum_inv -} - pub fn modulus() -> BigUint { BigUint::from_bytes_le((-F::ONE).to_repr().as_ref()) + 1u64 } @@ -215,7 +184,7 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { let (a0, a1) = points[0]; let (b0, b1) = points[1]; assert_ne!(a0, b0); - a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() + a1 + (x - a0) * (b1 - a1) * (b0 - a0).inverse() } pub fn degree_2_zero_plus_one(poly: &[F]) -> F { @@ -229,7 +198,7 @@ pub fn degree_2_eval(poly: &[F], point: F) -> F { pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; bytes.iter().for_each(|b| { - res += E::BaseField::from(u64::from(*b)); + res += E::BaseField::from_canonical_u8(*b); }); res } diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 64184fcb4..d3850d873 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ark_std::{rand::RngCore, test_rng}; use ff_ext::{ExtensionField, GoldilocksExt2}; use multilinear_extensions::virtual_poly::VirtualPolynomial; -use p3_field::FieldAlgebra; +use p3_field::{Field, FieldAlgebra}; use p3_goldilocks::MdsMatrixGoldilocks; use p3_mds::MdsPermutation; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};