Skip to content

Commit

Permalink
Merge branch 'master' into dependabot/cargo/blake2-0.10
Browse files Browse the repository at this point in the history
  • Loading branch information
weikengchen authored Oct 11, 2022
2 parents 2a2d5a7 + 4298076 commit 7579b60
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 44 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

### Breaking changes

- [\#55](https://github.com/arkworks-rs/sumcheck/pull/55) Change the function signatures of `IPForMLSumcheck::verify_round` and `IPForMLSumcheck::prove_round`.

### Features

### Improvements

- [\#55](https://github.com/arkworks-rs/sumcheck/pull/55) Improve the interpolation performance and avoid unnecessary state clones.

### Bug fixes

## v0.3.0
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ ark-serialize = { version = "^0.3.0", default-features = false, features = ["der
ark-std = { version = "^0.3.0", default-features = false }
ark-poly = { version = "^0.3.0", default-features = false }
blake2 = { version = "0.10", default-features = false }
hashbrown = { version = "0.11.2" }
hashbrown = { version = "0.12.3" }
rayon = { version = "1", optional = true }

[dev-dependencies]
Expand Down
13 changes: 5 additions & 8 deletions src/gkr_round_sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ impl<F: Field> GKRRoundSumcheck<F> {
let mut phase1_prover_msgs = Vec::with_capacity(dim);
let mut u = Vec::with_capacity(dim);
for _ in 0..dim {
let (pm, ps) = IPForMLSumcheck::prove_round(phase1_ps, &phase1_vm);
phase1_ps = ps;
let pm = IPForMLSumcheck::prove_round(&mut phase1_ps, &phase1_vm);

rng.feed(&pm).unwrap();
phase1_prover_msgs.push(pm);
let vm = IPForMLSumcheck::sample_round(&mut rng);
Expand All @@ -121,8 +121,7 @@ impl<F: Field> GKRRoundSumcheck<F> {
let mut phase2_prover_msgs = Vec::with_capacity(dim);
let mut v = Vec::with_capacity(dim);
for _ in 0..dim {
let (pm, ps) = IPForMLSumcheck::prove_round(phase2_ps, &phase2_vm);
phase2_ps = ps;
let pm = IPForMLSumcheck::prove_round(&mut phase2_ps, &phase2_vm);
rng.feed(&pm).unwrap();
phase2_prover_msgs.push(pm);
let vm = IPForMLSumcheck::sample_round(&mut rng);
Expand Down Expand Up @@ -160,8 +159,7 @@ impl<F: Field> GKRRoundSumcheck<F> {
for i in 0..dim {
let pm = &proof.phase1_sumcheck_msgs[i];
rng.feed(pm).unwrap();
let result = IPForMLSumcheck::verify_round((*pm).clone(), phase1_vs, &mut rng);
phase1_vs = result.1;
let _result = IPForMLSumcheck::verify_round((*pm).clone(), &mut phase1_vs, &mut rng);
}
let phase1_subclaim = IPForMLSumcheck::check_and_generate_subclaim(phase1_vs, claimed_sum)?;
let u = phase1_subclaim.point;
Expand All @@ -173,8 +171,7 @@ impl<F: Field> GKRRoundSumcheck<F> {
for i in 0..dim {
let pm = &proof.phase2_sumcheck_msgs[i];
rng.feed(pm).unwrap();
let result = IPForMLSumcheck::verify_round((*pm).clone(), phase2_vs, &mut rng);
phase2_vs = result.1;
let _result = IPForMLSumcheck::verify_round((*pm).clone(), &mut phase2_vs, &mut rng);
}
let phase2_subclaim = IPForMLSumcheck::check_and_generate_subclaim(
phase2_vs,
Expand Down
12 changes: 6 additions & 6 deletions src/ml_sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ impl<F: Field> MLSumcheck<F> {
let mut verifier_msg = None;
let mut prover_msgs = Vec::with_capacity(polynomial.num_variables);
for _ in 0..polynomial.num_variables {
let (prover_msg, prover_state_new) =
IPForMLSumcheck::prove_round(prover_state, &verifier_msg);
prover_state = prover_state_new;
let prover_msg = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg);
fs_rng.feed(&prover_msg)?;
prover_msgs.push(prover_msg);
verifier_msg = Some(IPForMLSumcheck::sample_round(&mut fs_rng));
Expand All @@ -70,9 +68,11 @@ impl<F: Field> MLSumcheck<F> {
for i in 0..polynomial_info.num_variables {
let prover_msg = proof.get(i).expect("proof is incomplete");
fs_rng.feed(prover_msg)?;
let result =
IPForMLSumcheck::verify_round((*prover_msg).clone(), verifier_state, &mut fs_rng);
verifier_state = result.1;
let _verifier_msg = IPForMLSumcheck::verify_round(
(*prover_msg).clone(),
&mut verifier_state,
&mut fs_rng,
);
}

Ok(IPForMLSumcheck::check_and_generate_subclaim(
Expand Down
13 changes: 5 additions & 8 deletions src/ml_sumcheck/protocol/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ impl<F: Field> IPForMLSumcheck<F> {
///
/// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
pub fn prove_round(
mut prover_state: ProverState<F>,
prover_state: &mut ProverState<F>,
v_msg: &Option<VerifierMsg<F>>,
) -> (ProverMsg<F>, ProverState<F>) {
) -> ProverMsg<F> {
if let Some(msg) = v_msg {
if prover_state.round == 0 {
panic!("first round should be prover first.");
Expand Down Expand Up @@ -120,11 +120,8 @@ impl<F: Field> IPForMLSumcheck<F> {
}
}

(
ProverMsg {
evaluations: products_sum,
},
prover_state,
)
ProverMsg {
evaluations: products_sum,
}
}
}
191 changes: 175 additions & 16 deletions src/ml_sumcheck/protocol/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ impl<F: Field> IPForMLSumcheck<F> {
/// the last step.
pub fn verify_round<R: RngCore>(
prover_msg: ProverMsg<F>,
mut verifier_state: VerifierState<F>,
verifier_state: &mut VerifierState<F>,
rng: &mut R,
) -> (Option<VerifierMsg<F>>, VerifierState<F>) {
) -> Option<VerifierMsg<F>> {
if verifier_state.finished {
panic!("Incorrect verifier state: Verifier is already finished.");
}
Expand All @@ -79,7 +79,7 @@ impl<F: Field> IPForMLSumcheck<F> {
} else {
verifier_state.round += 1;
}
(Some(msg), verifier_state)
Some(msg)
}

/// verify the sumcheck phase, and generate the subclaim
Expand Down Expand Up @@ -132,22 +132,181 @@ impl<F: Field> IPForMLSumcheck<F> {
}
}

/// interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this polynomial at `eval_at`.
/// interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this
/// polynomial at `eval_at`:
/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j))
pub(crate) fn interpolate_uni_poly<F: Field>(p_i: &[F], eval_at: F) -> F {
let mut result = F::zero();
let mut i = F::zero();
for term in p_i.iter() {
let mut term = *term;
let mut j = F::zero();
for _ in 0..p_i.len() {
if j != i {
term = term * (eval_at - j) / (i - j)
let len = p_i.len();

let mut evals = vec![];

let mut prod = eval_at;
evals.push(eval_at);

// `prod = \prod_{j} (eval_at - j)`
for e in 1..len {
let tmp = eval_at - F::from(e as u64);
evals.push(tmp);
prod *= tmp;
}
let mut res = F::zero();
// we want to compute \prod (j!=i) (i-j) for a given i
//
// we start from the last step, which is
// denom[len-1] = (len-1) * (len-2) *... * 2 * 1
// the step before that is
// denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1
// and the step before that is
// denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2
//
// i.e., for any i, the one before this will be derived from
// denom[i-1] = - denom[i] * (len-i) / i
//
// that is, we only need to store
// - the last denom for i = len-1, and
// - the ratio between the current step and the last step, which is the
// product of -(len-i) / i from all previous steps and we store
// this product as a fraction number to reduce field divisions.

// We know
// - 2^61 < factorial(20) < 2^62
// - 2^122 < factorial(33) < 2^123
// so we will be able to compute the ratio
// - for len <= 20 with i64
// - for len <= 33 with i128
// - for len > 33 with BigInt
if p_i.len() <= 20 {
let last_denom = F::from(u64_factorial(len - 1));
let mut ratio_numerator = 1i64;
let mut ratio_enumerator = 1u64;

for i in (0..len).rev() {
let ratio_numerator_f = if ratio_numerator < 0 {
-F::from((-ratio_numerator) as u64)
} else {
F::from(ratio_numerator as u64)
};

res += p_i[i] * prod * F::from(ratio_enumerator)
/ (last_denom * ratio_numerator_f * evals[i]);

// compute ratio for the next step which is current_ratio * -(len-i)/i
if i != 0 {
ratio_numerator *= -(len as i64 - i as i64);
ratio_enumerator *= i as u64;
}
}
} else if p_i.len() <= 33 {
let last_denom = F::from(u128_factorial(len - 1));
let mut ratio_numerator = 1i128;
let mut ratio_enumerator = 1u128;

for i in (0..len).rev() {
let ratio_numerator_f = if ratio_numerator < 0 {
-F::from((-ratio_numerator) as u128)
} else {
F::from(ratio_numerator as u128)
};

res += p_i[i] * prod * F::from(ratio_enumerator)
/ (last_denom * ratio_numerator_f * evals[i]);

// compute ratio for the next step which is current_ratio * -(len-i)/i
if i != 0 {
ratio_numerator *= -(len as i128 - i as i128);
ratio_enumerator *= i as u128;
}
}
} else {
// since we are using field operations, we can merge
// `last_denom` and `ratio_numerator` into a single field element.
let mut denom_up = field_factorial::<F>(len - 1);
let mut denom_down = F::one();

for i in (0..len).rev() {
res += p_i[i] * prod * denom_down / (denom_up * evals[i]);

// compute denom for the next step is -current_denom * (len-i)/i
if i != 0 {
denom_up *= -F::from((len - i) as u64);
denom_down *= F::from(i as u64);
}
j += F::one();
}
i += F::one();
result += term;
}

result
res
}

/// compute the factorial(a) = 1 * 2 * ... * a
#[inline]
fn field_factorial<F: Field>(a: usize) -> F {
let mut res = F::one();
for i in 1..=a {
res *= F::from(i as u64);
}
res
}

/// compute the factorial(a) = 1 * 2 * ... * a
#[inline]
fn u128_factorial(a: usize) -> u128 {
let mut res = 1u128;
for i in 1..=a {
res *= i as u128;
}
res
}

/// compute the factorial(a) = 1 * 2 * ... * a
#[inline]
fn u64_factorial(a: usize) -> u64 {
let mut res = 1u64;
for i in 1..=a {
res *= i as u64;
}
res
}

#[cfg(test)]
mod test {
use crate::ml_sumcheck::protocol::verifier::interpolate_uni_poly;
use ark_poly::univariate::DensePolynomial;
use ark_poly::Polynomial;
use ark_poly::UVPolynomial;
use ark_std::vec::Vec;
use ark_std::UniformRand;

type F = ark_test_curves::bls12_381::Fr;

#[test]
fn test_interpolation() {
let mut prng = ark_std::test_rng();

// test a polynomial with 20 known points, i.e., with degree 19
let poly = DensePolynomial::<F>::rand(20 - 1, &mut prng);
let evals = (0..20)
.map(|i| poly.evaluate(&F::from(i)))
.collect::<Vec<F>>();
let query = F::rand(&mut prng);

assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query));

// test a polynomial with 33 known points, i.e., with degree 32
let poly = DensePolynomial::<F>::rand(33 - 1, &mut prng);
let evals = (0..33)
.map(|i| poly.evaluate(&F::from(i)))
.collect::<Vec<F>>();
let query = F::rand(&mut prng);

assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query));

// test a polynomial with 64 known points, i.e., with degree 63
let poly = DensePolynomial::<F>::rand(64 - 1, &mut prng);
let evals = (0..64)
.map(|i| poly.evaluate(&F::from(i)))
.collect::<Vec<F>>();
let query = F::rand(&mut prng);

assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query));
}
}
8 changes: 3 additions & 5 deletions src/ml_sumcheck/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,10 @@ fn test_protocol(nv: usize, num_multiplicands_range: (usize, usize), num_product
let mut verifier_state = IPForMLSumcheck::verifier_init(&poly_info);
let mut verifier_msg = None;
for _ in 0..poly.num_variables {
let result = IPForMLSumcheck::prove_round(prover_state, &verifier_msg);
prover_state = result.1;
let (verifier_msg2, verifier_state2) =
IPForMLSumcheck::verify_round(result.0, verifier_state, &mut rng);
let prover_message = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg);
let verifier_msg2 =
IPForMLSumcheck::verify_round(prover_message, &mut verifier_state, &mut rng);
verifier_msg = verifier_msg2;
verifier_state = verifier_state2;
}
let subclaim = IPForMLSumcheck::check_and_generate_subclaim(verifier_state, asserted_sum)
.expect("fail to generate subclaim");
Expand Down

0 comments on commit 7579b60

Please sign in to comment.