diff --git a/CHANGELOG.md b/CHANGELOG.md index e50fd30..5078304 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,14 @@ ### Breaking changes +- [\#55](https://github.com/arkworks-rs/sumcheck/pull/55) Change the function signatures of `IPForMLSumcheck::verify_round` and `IPForMLSumcheck::prove_round`. + ### Features ### Improvements +- [\#55](https://github.com/arkworks-rs/sumcheck/pull/55) Improve the interpolation performance and avoid unnecessary state clones. + ### Bug fixes ## v0.3.0 diff --git a/Cargo.toml b/Cargo.toml index 6bee565..60b72e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ ark-serialize = { version = "^0.3.0", default-features = false, features = ["der ark-std = { version = "^0.3.0", default-features = false } ark-poly = { version = "^0.3.0", default-features = false } blake2 = { version = "0.10", default-features = false } -hashbrown = { version = "0.11.2" } +hashbrown = { version = "0.12.3" } rayon = { version = "1", optional = true } [dev-dependencies] diff --git a/src/gkr_round_sumcheck/mod.rs b/src/gkr_round_sumcheck/mod.rs index 4d4d3da..cc7c2ba 100644 --- a/src/gkr_round_sumcheck/mod.rs +++ b/src/gkr_round_sumcheck/mod.rs @@ -106,8 +106,8 @@ impl GKRRoundSumcheck { let mut phase1_prover_msgs = Vec::with_capacity(dim); let mut u = Vec::with_capacity(dim); for _ in 0..dim { - let (pm, ps) = IPForMLSumcheck::prove_round(phase1_ps, &phase1_vm); - phase1_ps = ps; + let pm = IPForMLSumcheck::prove_round(&mut phase1_ps, &phase1_vm); + rng.feed(&pm).unwrap(); phase1_prover_msgs.push(pm); let vm = IPForMLSumcheck::sample_round(&mut rng); @@ -121,8 +121,7 @@ impl GKRRoundSumcheck { let mut phase2_prover_msgs = Vec::with_capacity(dim); let mut v = Vec::with_capacity(dim); for _ in 0..dim { - let (pm, ps) = IPForMLSumcheck::prove_round(phase2_ps, &phase2_vm); - phase2_ps = ps; + let pm = IPForMLSumcheck::prove_round(&mut phase2_ps, &phase2_vm); rng.feed(&pm).unwrap(); phase2_prover_msgs.push(pm); let vm = IPForMLSumcheck::sample_round(&mut rng); @@ -160,8 +159,7 @@ impl GKRRoundSumcheck { for i in 0..dim { let pm = &proof.phase1_sumcheck_msgs[i]; rng.feed(pm).unwrap(); - let result = IPForMLSumcheck::verify_round((*pm).clone(), phase1_vs, &mut rng); - phase1_vs = result.1; + let _result = IPForMLSumcheck::verify_round((*pm).clone(), &mut phase1_vs, &mut rng); } let phase1_subclaim = IPForMLSumcheck::check_and_generate_subclaim(phase1_vs, claimed_sum)?; let u = phase1_subclaim.point; @@ -173,8 +171,7 @@ impl GKRRoundSumcheck { for i in 0..dim { let pm = &proof.phase2_sumcheck_msgs[i]; rng.feed(pm).unwrap(); - let result = IPForMLSumcheck::verify_round((*pm).clone(), phase2_vs, &mut rng); - phase2_vs = result.1; + let _result = IPForMLSumcheck::verify_round((*pm).clone(), &mut phase2_vs, &mut rng); } let phase2_subclaim = IPForMLSumcheck::check_and_generate_subclaim( phase2_vs, diff --git a/src/ml_sumcheck/mod.rs b/src/ml_sumcheck/mod.rs index 143beb7..6760039 100644 --- a/src/ml_sumcheck/mod.rs +++ b/src/ml_sumcheck/mod.rs @@ -47,9 +47,7 @@ impl MLSumcheck { let mut verifier_msg = None; let mut prover_msgs = Vec::with_capacity(polynomial.num_variables); for _ in 0..polynomial.num_variables { - let (prover_msg, prover_state_new) = - IPForMLSumcheck::prove_round(prover_state, &verifier_msg); - prover_state = prover_state_new; + let prover_msg = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg); fs_rng.feed(&prover_msg)?; prover_msgs.push(prover_msg); verifier_msg = Some(IPForMLSumcheck::sample_round(&mut fs_rng)); @@ -70,9 +68,11 @@ impl MLSumcheck { for i in 0..polynomial_info.num_variables { let prover_msg = proof.get(i).expect("proof is incomplete"); fs_rng.feed(prover_msg)?; - let result = - IPForMLSumcheck::verify_round((*prover_msg).clone(), verifier_state, &mut fs_rng); - verifier_state = result.1; + let _verifier_msg = IPForMLSumcheck::verify_round( + (*prover_msg).clone(), + &mut verifier_state, + &mut fs_rng, + ); } Ok(IPForMLSumcheck::check_and_generate_subclaim( diff --git a/src/ml_sumcheck/protocol/prover.rs b/src/ml_sumcheck/protocol/prover.rs index 2150d8a..cbca997 100644 --- a/src/ml_sumcheck/protocol/prover.rs +++ b/src/ml_sumcheck/protocol/prover.rs @@ -67,9 +67,9 @@ impl IPForMLSumcheck { /// /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). pub fn prove_round( - mut prover_state: ProverState, + prover_state: &mut ProverState, v_msg: &Option>, - ) -> (ProverMsg, ProverState) { + ) -> ProverMsg { if let Some(msg) = v_msg { if prover_state.round == 0 { panic!("first round should be prover first."); @@ -120,11 +120,8 @@ impl IPForMLSumcheck { } } - ( - ProverMsg { - evaluations: products_sum, - }, - prover_state, - ) + ProverMsg { + evaluations: products_sum, + } } } diff --git a/src/ml_sumcheck/protocol/verifier.rs b/src/ml_sumcheck/protocol/verifier.rs index 185edf3..9f7790c 100644 --- a/src/ml_sumcheck/protocol/verifier.rs +++ b/src/ml_sumcheck/protocol/verifier.rs @@ -53,9 +53,9 @@ impl IPForMLSumcheck { /// the last step. pub fn verify_round( prover_msg: ProverMsg, - mut verifier_state: VerifierState, + verifier_state: &mut VerifierState, rng: &mut R, - ) -> (Option>, VerifierState) { + ) -> Option> { if verifier_state.finished { panic!("Incorrect verifier state: Verifier is already finished."); } @@ -79,7 +79,7 @@ impl IPForMLSumcheck { } else { verifier_state.round += 1; } - (Some(msg), verifier_state) + Some(msg) } /// verify the sumcheck phase, and generate the subclaim @@ -132,22 +132,181 @@ impl IPForMLSumcheck { } } -/// interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this polynomial at `eval_at`. +/// interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this +/// polynomial at `eval_at`: +/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j)) pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { - let mut result = F::zero(); - let mut i = F::zero(); - for term in p_i.iter() { - let mut term = *term; - let mut j = F::zero(); - for _ in 0..p_i.len() { - if j != i { - term = term * (eval_at - j) / (i - j) + let len = p_i.len(); + + let mut evals = vec![]; + + let mut prod = eval_at; + evals.push(eval_at); + + // `prod = \prod_{j} (eval_at - j)` + for e in 1..len { + let tmp = eval_at - F::from(e as u64); + evals.push(tmp); + prod *= tmp; + } + let mut res = F::zero(); + // we want to compute \prod (j!=i) (i-j) for a given i + // + // we start from the last step, which is + // denom[len-1] = (len-1) * (len-2) *... * 2 * 1 + // the step before that is + // denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1 + // and the step before that is + // denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2 + // + // i.e., for any i, the one before this will be derived from + // denom[i-1] = - denom[i] * (len-i) / i + // + // that is, we only need to store + // - the last denom for i = len-1, and + // - the ratio between the current step and the last step, which is the + // product of -(len-i) / i from all previous steps and we store + // this product as a fraction number to reduce field divisions. + + // We know + // - 2^61 < factorial(20) < 2^62 + // - 2^122 < factorial(33) < 2^123 + // so we will be able to compute the ratio + // - for len <= 20 with i64 + // - for len <= 33 with i128 + // - for len > 33 with BigInt + if p_i.len() <= 20 { + let last_denom = F::from(u64_factorial(len - 1)); + let mut ratio_numerator = 1i64; + let mut ratio_enumerator = 1u64; + + for i in (0..len).rev() { + let ratio_numerator_f = if ratio_numerator < 0 { + -F::from((-ratio_numerator) as u64) + } else { + F::from(ratio_numerator as u64) + }; + + res += p_i[i] * prod * F::from(ratio_enumerator) + / (last_denom * ratio_numerator_f * evals[i]); + + // compute ratio for the next step which is current_ratio * -(len-i)/i + if i != 0 { + ratio_numerator *= -(len as i64 - i as i64); + ratio_enumerator *= i as u64; + } + } + } else if p_i.len() <= 33 { + let last_denom = F::from(u128_factorial(len - 1)); + let mut ratio_numerator = 1i128; + let mut ratio_enumerator = 1u128; + + for i in (0..len).rev() { + let ratio_numerator_f = if ratio_numerator < 0 { + -F::from((-ratio_numerator) as u128) + } else { + F::from(ratio_numerator as u128) + }; + + res += p_i[i] * prod * F::from(ratio_enumerator) + / (last_denom * ratio_numerator_f * evals[i]); + + // compute ratio for the next step which is current_ratio * -(len-i)/i + if i != 0 { + ratio_numerator *= -(len as i128 - i as i128); + ratio_enumerator *= i as u128; + } + } + } else { + // since we are using field operations, we can merge + // `last_denom` and `ratio_numerator` into a single field element. + let mut denom_up = field_factorial::(len - 1); + let mut denom_down = F::one(); + + for i in (0..len).rev() { + res += p_i[i] * prod * denom_down / (denom_up * evals[i]); + + // compute denom for the next step is -current_denom * (len-i)/i + if i != 0 { + denom_up *= -F::from((len - i) as u64); + denom_down *= F::from(i as u64); } - j += F::one(); } - i += F::one(); - result += term; } - result + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn field_factorial(a: usize) -> F { + let mut res = F::one(); + for i in 1..=a { + res *= F::from(i as u64); + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn u128_factorial(a: usize) -> u128 { + let mut res = 1u128; + for i in 1..=a { + res *= i as u128; + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn u64_factorial(a: usize) -> u64 { + let mut res = 1u64; + for i in 1..=a { + res *= i as u64; + } + res +} + +#[cfg(test)] +mod test { + use crate::ml_sumcheck::protocol::verifier::interpolate_uni_poly; + use ark_poly::univariate::DensePolynomial; + use ark_poly::Polynomial; + use ark_poly::UVPolynomial; + use ark_std::vec::Vec; + use ark_std::UniformRand; + + type F = ark_test_curves::bls12_381::Fr; + + #[test] + fn test_interpolation() { + let mut prng = ark_std::test_rng(); + + // test a polynomial with 20 known points, i.e., with degree 19 + let poly = DensePolynomial::::rand(20 - 1, &mut prng); + let evals = (0..20) + .map(|i| poly.evaluate(&F::from(i))) + .collect::>(); + let query = F::rand(&mut prng); + + assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)); + + // test a polynomial with 33 known points, i.e., with degree 32 + let poly = DensePolynomial::::rand(33 - 1, &mut prng); + let evals = (0..33) + .map(|i| poly.evaluate(&F::from(i))) + .collect::>(); + let query = F::rand(&mut prng); + + assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)); + + // test a polynomial with 64 known points, i.e., with degree 63 + let poly = DensePolynomial::::rand(64 - 1, &mut prng); + let evals = (0..64) + .map(|i| poly.evaluate(&F::from(i))) + .collect::>(); + let query = F::rand(&mut prng); + + assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)); + } } diff --git a/src/ml_sumcheck/test.rs b/src/ml_sumcheck/test.rs index 1f71a3d..a59ac8b 100644 --- a/src/ml_sumcheck/test.rs +++ b/src/ml_sumcheck/test.rs @@ -81,12 +81,10 @@ fn test_protocol(nv: usize, num_multiplicands_range: (usize, usize), num_product let mut verifier_state = IPForMLSumcheck::verifier_init(&poly_info); let mut verifier_msg = None; for _ in 0..poly.num_variables { - let result = IPForMLSumcheck::prove_round(prover_state, &verifier_msg); - prover_state = result.1; - let (verifier_msg2, verifier_state2) = - IPForMLSumcheck::verify_round(result.0, verifier_state, &mut rng); + let prover_message = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg); + let verifier_msg2 = + IPForMLSumcheck::verify_round(prover_message, &mut verifier_state, &mut rng); verifier_msg = verifier_msg2; - verifier_state = verifier_state2; } let subclaim = IPForMLSumcheck::check_and_generate_subclaim(verifier_state, asserted_sum) .expect("fail to generate subclaim");