From ee2914e9eb9f753bdd4d2f9ae5059b93926e41b7 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Thu, 6 Feb 2025 11:05:24 -0500 Subject: [PATCH] Combined read/write-checking sumcheck --- jolt-core/src/subprotocols/twist.rs | 423 +++++++++++++++++++++------- 1 file changed, 322 insertions(+), 101 deletions(-) diff --git a/jolt-core/src/subprotocols/twist.rs b/jolt-core/src/subprotocols/twist.rs index ed8835bee..f99dced70 100644 --- a/jolt-core/src/subprotocols/twist.rs +++ b/jolt-core/src/subprotocols/twist.rs @@ -31,36 +31,80 @@ pub enum TwistAlgorithm { Alternative, } -pub fn prove_read_checking( - increments: Vec<(usize, i64)>, - rv: MultilinearPolynomial, +pub struct TwistProof { + /// Joint sumcheck proof for the read-checking and write-checking sumchecks + /// (steps 3 and 4 of Figure 9). + read_write_checking_sumcheck: SumcheckInstanceProof, + /// The claimed evaluation ra(r_address, r_cycle) output by the read/write- + /// checking sumcheck. + ra_claim: F, + /// The claimed evaluation rv(r') proven by the read-checking sumcheck. + rv_claim: F, + /// The claimed evaluation wa(r_address, r_cycle) output by the read/write- + /// checking sumcheck. + wa_claim: F, + /// The claimed evaluation wv(r_address, r_cycle) output by the read/write- + /// checking sumcheck. + wv_claim: F, + /// The claimed evaluation val(r_address, r_cycle) output by the read/write- + /// checking sumcheck. + val_claim: F, + /// The claimed evaluation Inc(r, r') proven by the write-checking sumcheck. + inc_claim: F, +} + +pub fn prove_read_write_checking( + read_addresses: Vec, + read_values: Vec, + write_addresses: Vec, + write_values: Vec, + write_increments: Vec, + r: Vec, r_prime: Vec, - K: usize, transcript: &mut ProofTranscript, algorithm: TwistAlgorithm, -) -> (SumcheckInstanceProof, Vec, F, F) { - let rv_eval = rv.evaluate(&r_prime); +) -> TwistProof { match algorithm { - TwistAlgorithm::Local => { - prove_read_checking_local(increments, &r_prime, K, rv_eval, transcript) - } + TwistAlgorithm::Local => prove_read_write_checking_local( + read_addresses, + read_values, + write_addresses, + write_values, + write_increments, + &r, + &r_prime, + transcript, + ), TwistAlgorithm::Alternative => unimplemented!(), } } -fn prove_read_checking_local( - increments: Vec<(usize, i64)>, +fn prove_read_write_checking_local( + read_addresses: Vec, + read_values: Vec, + write_addresses: Vec, + write_values: Vec, + write_increments: Vec, + r: &[F], r_prime: &[F], - K: usize, - claimed_evaluation: F, transcript: &mut ProofTranscript, -) -> (SumcheckInstanceProof, Vec, F, F) { +) -> TwistProof { const DEGREE: usize = 3; + let K = r.len().pow2(); let T = r_prime.len().pow2(); - debug_assert_eq!(increments.len(), T); + + debug_assert_eq!(read_addresses.len(), T); + debug_assert_eq!(read_values.len(), T); + debug_assert_eq!(write_addresses.len(), T); + debug_assert_eq!(write_values.len(), T); + debug_assert_eq!(write_increments.len(), T); + + // Used to batch the read-checking and write-checking sumcheck + // (see Section 4.2.1) + let z: F = transcript.challenge_scalar(); let num_rounds = K.log_2() + T.log_2(); - let mut r: Vec = Vec::with_capacity(num_rounds); + let mut r_sumcheck: Vec = Vec::with_capacity(num_rounds); let num_chunks = rayon::current_num_threads().next_power_of_two().min(T); let chunk_size = T / num_chunks; @@ -69,13 +113,13 @@ fn prove_read_checking_local( let mut val_test = { // Compute Val in cycle-major order, since we will be binding // from low-to-high starting with the cycle variables - let mut val: Vec = vec![0; K * T]; + let mut val: Vec = vec![0; K * T]; val.par_chunks_mut(T).enumerate().for_each(|(k, val_k)| { let mut current_val = 0; for j in 0..T { val_k[j] = current_val; - if increments[j].0 == k { - current_val += increments[j].1; + if write_addresses[j] == k { + current_val = write_values[j]; } } }); @@ -88,19 +132,34 @@ fn prove_read_checking_local( let mut ra: Vec = unsafe_allocate_zero_vec(K * T); ra.par_chunks_mut(T).enumerate().for_each(|(k, ra_k)| { for j in 0..T { - if increments[j].0 == k { + if read_addresses[j] == k { ra_k[j] = F::one(); } } }); MultilinearPolynomial::from(ra) }; + #[cfg(test)] + let mut wa_test = { + // Compute wa in cycle-major order, since we will be binding + // from low-to-high starting with the cycle variables + let mut wa: Vec = unsafe_allocate_zero_vec(K * T); + wa.par_chunks_mut(T).enumerate().for_each(|(k, wa_k)| { + for j in 0..T { + if write_addresses[j] == k { + wa_k[j] = F::one(); + } + } + }); + MultilinearPolynomial::from(wa) + }; - let deltas: Vec> = increments[..T - chunk_size] + let deltas: Vec> = write_addresses[..T - chunk_size] .par_chunks_exact(chunk_size) - .map(|chunk| { + .zip(write_increments[..T - chunk_size].par_chunks_exact(chunk_size)) + .map(|(address_chunk, increment_chunk)| { let mut delta = vec![0i64; K]; - for (k, increment) in chunk { + for (k, increment) in address_chunk.iter().zip(increment_chunk.iter()) { delta[*k] += increment; } delta @@ -120,6 +179,7 @@ fn prove_read_checking_local( .collect(); checkpoints.push(next_checkpoint); } + // TODO(moodlezoup): could potentially generate these checkpoints in the tracer let checkpoints: Vec> = checkpoints .into_par_iter() .map(|checkpoint| checkpoint.into_iter().map(|val| F::from_i64(val)).collect()) @@ -136,18 +196,25 @@ fn prove_read_checking_local( } } + // A table that, in round i of sumcheck, stores all evaluations + // EQ(x, r_i, ..., r_1) + // as x ranges over {0, 1}^i. + // (As described in "Computing other necessary arrays and worst-case + // accounting", Section 8.2.2) let mut A: Vec = unsafe_allocate_zero_vec(chunk_size); A[0] = F::one(); // Data structure described in Equation (72) - let mut I: Vec> = increments + let mut I: Vec> = write_addresses .par_chunks(chunk_size) + .zip(write_increments.par_chunks(chunk_size)) .enumerate() - .map(|(chunk_index, increments_chunk)| { + .map(|(chunk_index, (address_chunk, increment_chunk))| { // Row index of the I matrix let mut j = chunk_index * chunk_size; - let I_chunk = increments_chunk + let I_chunk = address_chunk .iter() + .zip(increment_chunk.iter()) .map(|(k, increment)| { let inc = (j, *k, F::zero(), F::from_i64(*increment)); j += 1; @@ -158,8 +225,28 @@ fn prove_read_checking_local( }) .collect(); - let mut eq = MultilinearPolynomial::from(EqPolynomial::evals(r_prime)); - let mut previous_claim = claimed_evaluation; + let rv = MultilinearPolynomial::from(read_values); + let mut wv = MultilinearPolynomial::from(write_values); + + // eq(r, k) + let mut eq_r = MultilinearPolynomial::from(EqPolynomial::evals(r)); + // eq(r', j) + let mut eq_r_prime = MultilinearPolynomial::from(EqPolynomial::evals(r_prime)); + + // rv(r') + let rv_eval = rv.evaluate(r_prime); + // Inc(r, r') + let inc_eval: F = write_addresses + .par_iter() + .zip(write_increments.par_iter()) + .enumerate() + .map(|(cycle, (address, increment))| { + eq_r.get_coeff(*address) * eq_r_prime.get_coeff(cycle) * F::from_i64(*increment) + }) + .sum(); + // Linear combination of the read-checking claim (which is rv(r')) and the + // write-checking claim (which is Inc(r, r')) + let mut previous_claim = rv_eval + z * inc_eval; let mut compressed_polys: Vec> = Vec::with_capacity(num_rounds); // First log(T / num_chunks) rounds of sumcheck @@ -171,9 +258,15 @@ fn prove_read_checking_local( let mut inner_sum = F::zero(); for k in 0..K { let kj = k * (T >> round) + j; + // read-checking sumcheck inner_sum += ra_test.get_bound_coeff(kj) * val_test.get_bound_coeff(kj); + // write-checking sumcheck + inner_sum += z + * eq_r.get_bound_coeff(k) + * wa_test.get_bound_coeff(kj) + * (wv.get_bound_coeff(j) - val_test.get_bound_coeff(kj)) } - expected_claim += eq.get_bound_coeff(j) * inner_sum; + expected_claim += eq_r_prime.get_bound_coeff(j) * inner_sum; } assert_eq!( expected_claim, previous_claim, @@ -206,6 +299,14 @@ fn prove_read_checking_local( // where j'' are the higher (log(T) - i - 1) bits of j' let mut ra: [Vec; 2] = [unsafe_allocate_zero_vec(K), unsafe_allocate_zero_vec(K)]; + // `wa[0]` will contain + // wa(k, j'', 0, r_i, ..., r_1) + // `wa[1]` will contain + // wa(k, j'', 1, r_i, ..., r_1) + // as we iterate over rows j' \in {0, 1}^(log(T) - i), + // where j'' are the higher (log(T) - i - 1) bits of j' + let mut wa: [Vec; 2] = + [unsafe_allocate_zero_vec(K), unsafe_allocate_zero_vec(K)]; // Iterate over I_chunk, two rows at a time. I_chunk @@ -236,22 +337,28 @@ fn prove_read_checking_local( } ra[0].iter_mut().for_each(|ra_k| *ra_k = F::zero()); + wa[0].iter_mut().for_each(|wa_k| *wa_k = F::zero()); for j in j_prime << round..(j_prime + 1) << round { - let (k, _) = increments[j]; let j_bound = j % (1 << round); + let k = read_addresses[j]; ra[0][k] += A[j_bound]; + let k = write_addresses[j]; + wa[0][k] += A[j_bound]; } ra[1].iter_mut().for_each(|ra_k| *ra_k = F::zero()); + wa[1].iter_mut().for_each(|wa_k| *wa_k = F::zero()); for j in (j_prime + 1) << round..(j_prime + 2) << round { - let (k, _) = increments[j]; let j_bound = j % (1 << round); + let k = read_addresses[j]; ra[1][k] += A[j_bound]; + let k = write_addresses[j]; + wa[1][k] += A[j_bound]; } #[cfg(test)] { - // Check val for k in 0..K { + // Check val assert_eq!( val_test.get_bound_coeff(k * (T >> round) + j_prime), val_j_r[0][k], @@ -260,9 +367,7 @@ fn prove_read_checking_local( val_test.get_bound_coeff(k * (T >> round) + j_prime + 1), val_j_r[1][k], ); - } - // Check ra - for k in 0..K { + // Check ra assert_eq!( ra_test.get_bound_coeff(k * (T >> round) + j_prime), ra[0][k] @@ -271,32 +376,55 @@ fn prove_read_checking_local( ra_test.get_bound_coeff(k * (T >> round) + j_prime + 1), ra[1][k] ); + // Check wa + assert_eq!( + wa_test.get_bound_coeff(k * (T >> round) + j_prime), + wa[0][k] + ); + assert_eq!( + wa_test.get_bound_coeff(k * (T >> round) + j_prime + 1), + wa[1][k] + ); } } + let eq_r_prime_evals = + eq_r_prime.sumcheck_evals(j_prime / 2, DEGREE, BindingOrder::LowToHigh); + let wv_evals = + wv.sumcheck_evals(j_prime / 2, DEGREE, BindingOrder::LowToHigh); + let mut inner_sum_evals = [F::zero(); 3]; for k in 0..K { let m_ra = ra[1][k] - ra[0][k]; let ra_eval_2 = ra[1][k] + m_ra; let ra_eval_3 = ra_eval_2 + m_ra; + let m_wa = wa[1][k] - wa[0][k]; + let wa_eval_2 = wa[1][k] + m_wa; + let wa_eval_3 = wa_eval_2 + m_wa; + let m_val = val_j_r[1][k] - val_j_r[0][k]; let val_eval_2 = val_j_r[1][k] + m_val; let val_eval_3 = val_eval_2 + m_val; + // Read-checking sumcheck inner_sum_evals[0] += ra[0][k].mul_0_optimized(val_j_r[0][k]); inner_sum_evals[1] += ra_eval_2.mul_0_optimized(val_eval_2); inner_sum_evals[2] += ra_eval_3.mul_0_optimized(val_eval_3); - } - let mut eq_evals = [eq.get_bound_coeff(j_prime), F::zero(), F::zero()]; - let m = eq.get_bound_coeff(j_prime + 1) - eq_evals[0]; - eq_evals[1] = eq.get_bound_coeff(j_prime + 1) + m; - eq_evals[2] = eq_evals[1] + m; + let z_eq_r = z * eq_r.get_coeff(k); + // Write-checking sumcheck + inner_sum_evals[0] += + z_eq_r * wa[0][k].mul_0_optimized(wv_evals[0] - val_j_r[0][k]); + inner_sum_evals[1] += + z_eq_r * wa_eval_2.mul_0_optimized(wv_evals[1] - val_eval_2); + inner_sum_evals[2] += + z_eq_r * wa_eval_3.mul_0_optimized(wv_evals[2] - val_eval_3); + } - evals[0] += eq_evals[0] * inner_sum_evals[0]; - evals[1] += eq_evals[1] * inner_sum_evals[1]; - evals[2] += eq_evals[2] * inner_sum_evals[2]; + evals[0] += eq_r_prime_evals[0] * inner_sum_evals[0]; + evals[1] += eq_r_prime_evals[1] * inner_sum_evals[1]; + evals[2] += eq_r_prime_evals[2] * inner_sum_evals[2]; }); evals @@ -324,7 +452,7 @@ fn prove_read_checking_local( compressed_polys.push(compressed_poly); let r_j = transcript.challenge_scalar::(); - r.push(r_j); + r_sumcheck.push(r_j); previous_claim = univariate_poly.evaluate(&r_j); @@ -359,12 +487,16 @@ fn prove_read_checking_local( I_chunk.truncate(next_bound_index); }); - eq.bind_parallel(r_j, BindingOrder::LowToHigh); + rayon::join( + || wv.bind_parallel(r_j, BindingOrder::LowToHigh), + || eq_r_prime.bind_parallel(r_j, BindingOrder::LowToHigh), + ); #[cfg(test)] { val_test.bind_parallel(r_j, BindingOrder::LowToHigh); ra_test.bind_parallel(r_j, BindingOrder::LowToHigh); + wa_test.bind_parallel(r_j, BindingOrder::LowToHigh); // Check that row indices of I are non-decreasing let mut current_row = 0; @@ -390,17 +522,17 @@ fn prove_read_checking_local( } // At this point I has been bound to a point where each chunk contains a single row, - // so we might as well materialize the full `ra` and `Val` polynomials and perform + // so we might as well materialize the full `ra`, `wa`, and `Val` polynomials and perform // standard sumcheck directly using those polynomials. - // TODO(moodlezoup): Generate ra and Val in address-major order and bind variables + // TODO(moodlezoup): Generate these polynomials in address-major order and bind variables // from high-to-low for remaining rounds? let mut ra: Vec = unsafe_allocate_zero_vec(K * num_chunks); ra.par_chunks_mut(num_chunks) .enumerate() .for_each(|(k, ra_chunk)| { - for (j, increment) in increments.iter().enumerate() { - if increment.0 == k { + for (j, address) in read_addresses.iter().enumerate() { + if *address == k { let j_unbound = j / chunk_size; let j_bound = j % chunk_size; ra_chunk[j_unbound] += A[j_bound]; @@ -409,6 +541,20 @@ fn prove_read_checking_local( }); let mut ra = MultilinearPolynomial::from(ra); + let mut wa: Vec = unsafe_allocate_zero_vec(K * num_chunks); + wa.par_chunks_mut(num_chunks) + .enumerate() + .for_each(|(k, wa_chunk)| { + for (j, address) in write_addresses.iter().enumerate() { + if *address == k { + let j_unbound = j / chunk_size; + let j_bound = j % chunk_size; + wa_chunk[j_unbound] += A[j_bound]; + } + } + }); + let mut wa = MultilinearPolynomial::from(wa); + let mut val: Vec = unsafe_allocate_zero_vec(K * num_chunks); val.par_chunks_mut(num_chunks) .enumerate() @@ -434,13 +580,19 @@ fn prove_read_checking_local( }); let mut val = MultilinearPolynomial::from(val); - // `ra` and `val` should match `ra_test` and `val_test`, respectively #[cfg(test)] { + // `ra` should match `ra_test` assert_eq!(ra.len(), ra_test.len()); for i in 0..ra.len() { assert_eq!(ra.get_bound_coeff(i), ra_test.get_bound_coeff(i)); } + // `wa` should match `wa_test` + assert_eq!(wa.len(), wa_test.len()); + for i in 0..wa.len() { + assert_eq!(wa.get_bound_coeff(i), wa_test.get_bound_coeff(i)); + } + // `val` should match `val_test` assert_eq!(val.len(), val_test.len()); for i in 0..val.len() { assert_eq!(val.get_bound_coeff(i), val_test.get_bound_coeff(i)); @@ -449,26 +601,35 @@ fn prove_read_checking_local( // Remaining rounds of sumcheck for _round in 0..num_rounds - chunk_size.log_2() { - let univariate_poly_evals: [F; 3] = if eq.len() > 1 { + let univariate_poly_evals: [F; 3] = if eq_r_prime.len() > 1 { // Not done binding cycle variables yet - (0..eq.len() / 2) + (0..eq_r_prime.len() / 2) .into_par_iter() .map(|j| { - let eq_evals = eq.sumcheck_evals(j, DEGREE, BindingOrder::LowToHigh); + let eq_r_prime_evals = + eq_r_prime.sumcheck_evals(j, DEGREE, BindingOrder::LowToHigh); + let wv_evals = wv.sumcheck_evals(j, DEGREE, BindingOrder::LowToHigh); let inner_sum_evals: [F; 3] = (0..K) .into_par_iter() .map(|k| { - let index = k * eq.len() / 2 + j; + let index = k * eq_r_prime.len() / 2 + j; let ra_evals = ra.sumcheck_evals(index, DEGREE, BindingOrder::LowToHigh); + let wa_evals = + wa.sumcheck_evals(index, DEGREE, BindingOrder::LowToHigh); let val_evals = val.sumcheck_evals(index, DEGREE, BindingOrder::LowToHigh); + let z_eq_r = z * eq_r.get_coeff(k); + [ - ra_evals[0] * val_evals[0], - ra_evals[1] * val_evals[1], - ra_evals[2] * val_evals[2], + ra_evals[0] * val_evals[0] + + z_eq_r * wa_evals[0] * (wv_evals[0] - val_evals[0]), + ra_evals[1] * val_evals[1] + + z_eq_r * wa_evals[1] * (wv_evals[1] - val_evals[1]), + ra_evals[2] * val_evals[2] + + z_eq_r * wa_evals[2] * (wv_evals[2] - val_evals[2]), ] }) .reduce( @@ -483,9 +644,9 @@ fn prove_read_checking_local( ); [ - eq_evals[0] * inner_sum_evals[0], - eq_evals[1] * inner_sum_evals[1], - eq_evals[2] * inner_sum_evals[2], + eq_r_prime_evals[0] * inner_sum_evals[0], + eq_r_prime_evals[1] * inner_sum_evals[1], + eq_r_prime_evals[2] * inner_sum_evals[2], ] }) .reduce( @@ -499,18 +660,27 @@ fn prove_read_checking_local( }, ) } else { - // Cycle variables are fully bound - let eq_r_prime_r = eq.final_sumcheck_claim(); + // Cycle variables are fully bound, so: + // eq(r', r_cycle) is a constant + let eq_r_prime_eval = eq_r_prime.final_sumcheck_claim(); + // ...and wv(r_cycle) is a constant + let wv_eval = wv.final_sumcheck_claim(); + let evals = (0..ra.len() / 2) .into_par_iter() .map(|k| { + let eq_r_evals = eq_r.sumcheck_evals(k, DEGREE, BindingOrder::LowToHigh); let ra_evals = ra.sumcheck_evals(k, DEGREE, BindingOrder::LowToHigh); + let wa_evals = wa.sumcheck_evals(k, DEGREE, BindingOrder::LowToHigh); let val_evals = val.sumcheck_evals(k, DEGREE, BindingOrder::LowToHigh); [ - ra_evals[0] * val_evals[0], - ra_evals[1] * val_evals[1], - ra_evals[2] * val_evals[2], + ra_evals[0] * val_evals[0] + + z * eq_r_evals[0] * wa_evals[0] * (wv_eval - val_evals[0]), + ra_evals[1] * val_evals[1] + + z * eq_r_evals[1] * wa_evals[1] * (wv_eval - val_evals[1]), + ra_evals[2] * val_evals[2] + + z * eq_r_evals[2] * wa_evals[2] * (wv_eval - val_evals[2]), ] }) .reduce( @@ -524,9 +694,9 @@ fn prove_read_checking_local( }, ); [ - eq_r_prime_r * evals[0], - eq_r_prime_r * evals[1], - eq_r_prime_r * evals[2], + eq_r_prime_eval * evals[0], + eq_r_prime_eval * evals[1], + eq_r_prime_eval * evals[2], ] }; @@ -542,26 +712,37 @@ fn prove_read_checking_local( compressed_polys.push(compressed_poly); let r_j = transcript.challenge_scalar::(); - r.push(r_j); + r_sumcheck.push(r_j); previous_claim = univariate_poly.evaluate(&r_j); // Bind polynomials - rayon::join( - || ra.bind_parallel(r_j, BindingOrder::LowToHigh), - || val.bind_parallel(r_j, BindingOrder::LowToHigh), - ); - if eq.len() > 1 { - eq.bind_parallel(r_j, BindingOrder::LowToHigh); + if eq_r_prime.len() > 1 { + // Bind a cycle variable j + // Note that `eq_r` is a polynomial over only the address variables, + // so it is not bound here + [&mut ra, &mut wa, &mut wv, &mut val, &mut eq_r_prime] + .into_par_iter() + .for_each(|poly| poly.bind_parallel(r_j, BindingOrder::LowToHigh)); + } else { + // Bind an address variable k + // Note that `wv` and `eq_r_prime` are polynomials over only the cycle + // variables, so they are not bound here + [&mut ra, &mut wa, &mut val, &mut eq_r] + .into_par_iter() + .for_each(|poly| poly.bind_parallel(r_j, BindingOrder::LowToHigh)); } } - ( - SumcheckInstanceProof::new(compressed_polys), - r, - ra.final_sumcheck_claim(), - val.final_sumcheck_claim(), - ) + TwistProof { + read_write_checking_sumcheck: SumcheckInstanceProof::new(compressed_polys), + ra_claim: ra.final_sumcheck_claim(), + rv_claim: rv_eval, + wa_claim: wa.final_sumcheck_claim(), + wv_claim: wv.final_sumcheck_claim(), + val_claim: val.final_sumcheck_claim(), + inc_claim: inc_eval, + } } /// Implements the sumcheck prover for the Val-evaluation sumcheck described in @@ -770,44 +951,84 @@ mod tests { } #[test] - fn read_checking_sumcheck_local() { + fn read_write_checking_sumcheck_local() { const K: usize = 16; const T: usize = 1 << 8; let mut rng = test_rng(); let mut registers = [0u32; K]; - let mut increments: Vec<(usize, i64)> = vec![]; - let mut rv: Vec = vec![]; + let mut read_addresses: Vec = Vec::with_capacity(T); + let mut read_values: Vec = Vec::with_capacity(T); + let mut write_addresses: Vec = Vec::with_capacity(T); + let mut write_values: Vec = Vec::with_capacity(T); + let mut write_increments: Vec = Vec::with_capacity(T); for _ in 0..T { - let k = rng.next_u32() as usize % K; - rv.push(registers[k]); - let new_value = rng.next_u32() % 10; - let inc = (new_value as i64) - (registers[k] as i64); - increments.push((k, inc)); - registers[k] = new_value; + // Random read register + let read_address = rng.next_u32() as usize % K; + // Random write register + let write_address = rng.next_u32() as usize % K; + read_addresses.push(read_address); + write_addresses.push(write_address); + // Read the value currently in the read register + read_values.push(registers[read_address]); + // Random write value + let write_value = rng.next_u32(); + write_values.push(write_value); + // The increment is the difference between the new value and the old value + let write_increment = (write_value as i64) - (registers[write_address] as i64); + write_increments.push(write_increment); + // Write the new value to the write register + registers[write_address] = write_value; } - let rv = MultilinearPolynomial::from(rv); let mut prover_transcript = KeccakTranscript::new(b"test_transcript"); + let r: Vec = prover_transcript.challenge_vector(K.log_2()); let r_prime: Vec = prover_transcript.challenge_vector(T.log_2()); - let rv_eval = rv.evaluate(&r_prime); - let (sumcheck_proof, _, ra_claim, val_claim) = - prove_read_checking_local(increments, &r_prime, K, rv_eval, &mut prover_transcript); + let twist_proof = prove_read_write_checking_local( + read_addresses, + read_values, + write_addresses, + write_values, + write_increments, + &r, + &r_prime, + &mut prover_transcript, + ); let mut verifier_transcript = KeccakTranscript::new(b"test_transcript"); verifier_transcript.compare_to(prover_transcript); + let _r: Vec = verifier_transcript.challenge_vector(K.log_2()); let _r_prime: Vec = verifier_transcript.challenge_vector(T.log_2()); - - let (sumcheck_claim, mut r) = sumcheck_proof - .verify(rv_eval, T.log_2() + K.log_2(), 3, &mut verifier_transcript) + let z: Fr = verifier_transcript.challenge_scalar(); + + let initial_sumcheck_claim = twist_proof.rv_claim + z * twist_proof.inc_claim; + + let (sumcheck_claim, mut r_sumcheck) = twist_proof + .read_write_checking_sumcheck + .verify( + initial_sumcheck_claim, + T.log_2() + K.log_2(), + 3, + &mut verifier_transcript, + ) .unwrap(); - r = r.into_iter().rev().collect(); - let (_r_address, r_cycle) = r.split_at(K.log_2()); + r_sumcheck = r_sumcheck.into_iter().rev().collect(); + let (r_address, r_cycle) = r_sumcheck.split_at(K.log_2()); + // eq(r', r_cycle) let eq_eval_cycle = EqPolynomial::new(r_prime).evaluate(r_cycle); - - assert_eq!(eq_eval_cycle * ra_claim * val_claim, sumcheck_claim); + // eq(r, r_address) + let eq_eval_address = EqPolynomial::new(r).evaluate(r_address); + + assert_eq!( + eq_eval_cycle * twist_proof.ra_claim * twist_proof.val_claim + + z * eq_eval_address + * eq_eval_cycle + * twist_proof.wa_claim + * (twist_proof.wv_claim - twist_proof.val_claim), + sumcheck_claim + ); } }