Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseFold: all open functions accept Arc instead of DenseMultilinearExtension #563

Draft
wants to merge 7 commits into
base: feat/basefold-refactor-extract-1
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mpcs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ print-trace = ["ark-std/print-trace"]
sanity-check = []

[[bench]]
name = "basefold"
harness = false
name = "basefold"

[[bench]]
harness = false
Expand Down
9 changes: 7 additions & 2 deletions mpcs/benches/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ fn bench_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>(
let eval = poly.evaluate(point.as_slice());
transcript.append_field_element_ext(&eval);
let transcript_for_bench = transcript.clone();
let poly = ArcMultilinearExtension::from(poly);
let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap();

group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| {
Expand Down Expand Up @@ -143,10 +144,14 @@ fn bench_batch_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>
let values: Vec<E> = evals
.iter()
.map(Evaluation::value)
.map(|x| *x)
.copied()
.collect::<Vec<E>>();
transcript.append_field_element_exts(values.as_slice());
let transcript_for_bench = transcript.clone();
let polys = polys
.iter()
.map(|poly| ArcMultilinearExtension::from(poly.clone()))
.collect::<Vec<_>>();
let proof =
Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap();

Expand Down Expand Up @@ -182,7 +187,7 @@ fn bench_batch_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>
let values: Vec<E> = evals
.iter()
.map(Evaluation::value)
.map(|x| *x)
.copied()
.collect::<Vec<E>>();
transcript.append_field_element_exts(values.as_slice());

Expand Down
60 changes: 34 additions & 26 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
inner_product, inner_product_three, interpolate_field_type_over_boolean_hypercube,
},
expression::{Expression, Query, Rotation},
ext_to_usize,
ext_to_usize, field_type_to_ext_vec,
hash::{Digest, write_digest_to_transcript},
log2_strict,
merkle_tree::MerkleTree,
Expand All @@ -34,7 +34,6 @@ use query_phase::{
prover_query_phase, simple_batch_prover_query_phase, simple_batch_verifier_query_phase,
verifier_query_phase,
};
use std::{borrow::BorrowMut, ops::Deref};
pub use structure::BasefoldSpec;
use structure::{BasefoldProof, ProofQueriesResultWithMerklePath};
use transcript::Transcript;
Expand All @@ -51,7 +50,6 @@ use rayon::{
iter::IntoParallelIterator,
prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator},
};
use std::borrow::Cow;
pub use sumcheck::{one_level_eval_hc, one_level_interp_hc};

type SumCheck<F> = ClassicSumCheck<CoefficientsProver<F>>;
Expand Down Expand Up @@ -466,7 +464,7 @@ where
/// will panic.
fn open(
pp: &Self::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Self::CommitmentWithData,
point: &[E],
_eval: &E, // Opening does not need eval, except for sanity check
Expand All @@ -480,7 +478,7 @@ where
// the protocol won't work, and saves no verifier work anyway.
// In this case, simply return the evaluations as trivial proof.
if comm.is_trivial::<Spec>() {
return Ok(Self::Proof::trivial(vec![poly.evaluations.clone()]));
return Ok(Self::Proof::trivial(vec![poly.evaluations().clone()]));
}

assert!(comm.num_vars >= Spec::get_basecode_msg_size_log());
Expand All @@ -499,8 +497,8 @@ where
point,
comm,
transcript,
poly.num_vars,
poly.num_vars - Spec::get_basecode_msg_size_log(),
poly.num_vars(),
poly.num_vars() - Spec::get_basecode_msg_size_log(),
);

// 2. Query phase. ---------------------------------------
Expand Down Expand Up @@ -546,15 +544,15 @@ where
/// not very useful in ceno.
fn batch_open(
pp: &Self::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
transcript: &mut Transcript<E>,
) -> Result<Self::Proof, Error> {
let timer = start_timer!(|| "Basefold::batch_open");
let num_vars = polys.iter().map(|poly| poly.num_vars).max().unwrap();
let min_num_vars = polys.iter().map(|p| p.num_vars).min().unwrap();
let num_vars = polys.iter().map(|poly| poly.num_vars()).max().unwrap();
let min_num_vars = polys.iter().map(|p| p.num_vars()).min().unwrap();
assert!(min_num_vars >= Spec::get_basecode_msg_size_log());

comms.iter().for_each(|comm| {
Expand Down Expand Up @@ -603,28 +601,31 @@ where
let merged_polys = evals.iter().zip(poly_iter_ext(&eq_xt)).fold(
// This folding will generate a vector of |points| pairs of (scalar, polynomial)
// The polynomials are initialized to zero, and the scalars are initialized to one
vec![(E::ONE, Cow::<DenseMultilinearExtension<E>>::default()); points.len()],
vec![(E::ONE, Vec::<E>::new()); points.len()],
|mut merged_polys, (eval, eq_xt_i)| {
// For each polynomial to open, eval.point() specifies which point it is to be opened at.
if merged_polys[eval.point()].1.num_vars == 0 {
if merged_polys[eval.point()].1.is_empty() {
// If the accumulator for this point is still the zero polynomial,
// directly assign the random coefficient and the polynomial to open to
// this accumulator
merged_polys[eval.point()] = (eq_xt_i, Cow::Borrowed(&polys[eval.poly()]));
merged_polys[eval.point()] = (
eq_xt_i,
field_type_to_ext_vec(polys[eval.poly()].evaluations()),
);
} else {
// If the accumulator is unempty now, first force its scalar to 1, i.e.,
// make (scalar, polynomial) to (1, scalar * polynomial)
let coeff = merged_polys[eval.point()].0;
if coeff != E::ONE {
merged_polys[eval.point()].0 = E::ONE;
multiply_poly(merged_polys[eval.point()].1.to_mut().borrow_mut(), &coeff);
multiply_poly(&mut merged_polys[eval.point()].1, &coeff);
}
// Equivalent to merged_poly += poly * batch_coeff. Note that
// add_assign_mixed_with_coeff allows adding two polynomials with
// different variables, and the result has the same number of vars
// with the larger one of the two added polynomials.
add_polynomial_with_coeff(
merged_polys[eval.point()].1.to_mut().borrow_mut(),
&mut merged_polys[eval.point()].1,
&polys[eval.poly()],
&eq_xt_i,
);
Expand All @@ -642,18 +643,16 @@ where
.iter()
.zip(&points)
.map(|((scalar, poly), point)| {
inner_product(
&poly_iter_ext(poly).collect_vec(),
build_eq_x_r_vec(point).iter(),
) * scalar
* E::from(1 << (num_vars - poly.num_vars))
inner_product(poly, build_eq_x_r_vec(point).iter())
* scalar
* E::from(1 << (num_vars - log2_strict(poly.len())))
// When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube
})
.sum::<E>();
assert_eq!(expected_sum, target_sum);

merged_polys.iter().enumerate().for_each(|(i, (_, poly))| {
assert_eq!(points[i].len(), poly.num_vars);
assert_eq!(points[i].len(), log2_strict(poly.len()));
});
}

Expand All @@ -666,12 +665,17 @@ where
* scalar
})
.sum();
let sumcheck_polys: Vec<&DenseMultilinearExtension<E>> = merged_polys
let sumcheck_polys: Vec<DenseMultilinearExtension<E>> = merged_polys
.iter()
.map(|(_, poly)| poly.deref())
.map(|(_, poly)| {
DenseMultilinearExtension::from_evaluations_ext_vec(
log2_strict(poly.len()),
poly.clone(),
)
})
.collect_vec();
let virtual_poly =
VirtualPolynomial::new(&expression, sumcheck_polys, &[], points.as_slice());
VirtualPolynomial::new(&expression, sumcheck_polys.iter(), &[], points.as_slice());

let (challenges, merged_poly_evals, sumcheck_proof) =
SumCheck::prove(&(), num_vars, virtual_poly, target_sum, transcript)?;
Expand All @@ -695,7 +699,7 @@ where
if cfg!(feature = "sanity-check") {
let poly_evals = polys
.iter()
.map(|poly| poly.evaluate(&challenges[..poly.num_vars]))
.map(|poly| poly.evaluate(&challenges[..poly.num_vars()]))
.collect_vec();
let new_target_sum = inner_product(&poly_evals, &coeffs);
let desired_sum = merged_polys
Expand All @@ -705,7 +709,11 @@ where
.map(|(((scalar, poly), point), evals_from_sum_check)| {
assert_eq!(
evals_from_sum_check,
poly.evaluate(&challenges[..poly.num_vars])
DenseMultilinearExtension::from_evaluations_ext_vec(
log2_strict(poly.len()),
poly.clone()
)
.evaluate(&challenges[..log2_strict(poly.len())])
);
*scalar
* evals_from_sum_check
Expand Down
24 changes: 15 additions & 9 deletions mpcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub fn pcs_batch_commit_and_write<E: ExtensionField, Pcs: PolynomialCommitmentSc

pub fn pcs_open<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
pp: &Pcs::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Pcs::CommitmentWithData,
point: &[E],
eval: &E,
Expand All @@ -73,7 +73,7 @@ pub fn pcs_open<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(

pub fn pcs_batch_open<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
pp: &Pcs::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Pcs::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
Expand Down Expand Up @@ -162,7 +162,7 @@ pub trait PolynomialCommitmentScheme<E: ExtensionField>: Clone + Debug {

fn open(
pp: &Self::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Self::CommitmentWithData,
point: &[E],
eval: &E,
Expand All @@ -171,7 +171,7 @@ pub trait PolynomialCommitmentScheme<E: ExtensionField>: Clone + Debug {

fn batch_open(
pp: &Self::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
Expand Down Expand Up @@ -226,7 +226,7 @@ where
{
fn ni_open(
pp: &Self::ProverParam,
poly: &DenseMultilinearExtension<E>,
poly: &ArcMultilinearExtension<E>,
comm: &Self::CommitmentWithData,
point: &[E],
eval: &E,
Expand All @@ -237,7 +237,7 @@ where

fn ni_batch_open(
pp: &Self::ProverParam,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
comms: &[Self::CommitmentWithData],
points: &[Vec<E>],
evals: &[Evaluation<E>],
Expand Down Expand Up @@ -323,17 +323,17 @@ use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension;
fn validate_input<E: ExtensionField>(
function: &str,
param_num_vars: usize,
polys: &[DenseMultilinearExtension<E>],
polys: &[ArcMultilinearExtension<E>],
points: &[Vec<E>],
) -> Result<(), Error> {
let polys = polys.iter().collect_vec();
let points = points.iter().collect_vec();
for poly in polys.iter() {
if param_num_vars < poly.num_vars {
if param_num_vars < poly.num_vars() {
return Err(err_too_many_variates(
function,
param_num_vars,
poly.num_vars,
poly.num_vars(),
));
}
}
Expand Down Expand Up @@ -462,6 +462,7 @@ pub mod test_util {
let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap();
let point = get_point_from_challenge(num_vars, &mut transcript);
let eval = poly.evaluate(point.as_slice());
let poly = ArcMultilinearExtension::from(poly);
transcript.append_field_element_ext(&eval);

(
Expand Down Expand Up @@ -533,6 +534,11 @@ pub mod test_util {
.collect::<Vec<E>>();
transcript.append_field_element_exts(values.as_slice());

let polys = polys
.iter()
.map(|poly| ArcMultilinearExtension::from(poly.clone()))
.collect_vec();

let proof =
Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap();
(comms, evals, proof, transcript.read_challenge())
Expand Down
Loading
Loading