From 3195eb005cca18ec1d584cfabaecd5f6bd221957 Mon Sep 17 00:00:00 2001
From: Gus Gutoski
Date: Tue, 3 Sep 2024 10:14:39 -0400
Subject: [PATCH] feat: multiplicity depends on payload size (#670)

* compute actual multiplicity from max_multiplicity

This commit defines a function `min_multiplicity` that computes the
actual multiplicity to use from `max_multiplicity` and `payload_len`.
The original argument `multiplicity` has been renamed to
`max_multiplicity` to indicate that it is an upper bound.

* compute multiplicity
* rename AdvzParams::multiplicity -> max_multiplicity
* new test max_multiplicity
* fix min_multiplicity
* WIP polynomial_internal constructs its own eval domain
* WIP
* WIP
* fix: correctness test now passes
* remove println
* fix test
* fix recover_payload
* delete field Advz::eval_domain (yay)
* fix test max_multiplicity
* remove unneeded arg from min_multiplicity
* remove unneeded arg from bytes_to_polys
* tidy bytes_to_polys
* refactor disperse, disperse_precompute into disperse_with_polys_and_commits
* move code from evaluate_polys, assemble_shares into disperse_with_polys_and_commits
* delete method Advz::polynomial
* rename polynomial_internal -> interpolate_polynomial, return VidResult, tidy
* delete field Advz::multi_open_domain, derive it on-the-fly from common.multiplicity (yay)
* min_multiplicity return VidResult, replace panic with error
* refactor eval_domain
* remove unhelpful comment
* remove more unneeded code
* modify test to allow nontrivial multiplicity
* test uses nontrivial multiplicity, fails
* fix payload prover with nontrivial multiplicity
* tests use Bn254 instead of Bls12_381 as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1736228179
* uncomment some test code as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1736335690
* address https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1736349890
* address https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1736358667
* fix typo as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1736383113
* delete debugging comment as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1737065017
* remove superfluous log message as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1738176332
* remove superfluous timer as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1738218808
* add defensive checks as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1738317888
* clarify comment as per https://github.com/EspressoSystems/jellyfish/pull/670#discussion_r1738324592
* final cleanup

---------

Co-authored-by: Anders Konring
---
 vid/src/advz.rs                | 578 +++++++++++++++++----------------
 vid/src/advz/payload_prover.rs | 132 ++++++--
 vid/src/advz/precomputable.rs  |  87 +----
 vid/src/advz/test.rs           | 121 ++++++-
 vid/tests/vid/mod.rs           |  22 +-
 5 files changed, 521 insertions(+), 419 deletions(-)

diff --git a/vid/src/advz.rs b/vid/src/advz.rs
index 2951879a9..756654076 100644
--- a/vid/src/advz.rs
+++ b/vid/src/advz.rs
@@ -83,15 +83,9 @@ where
 {
     recovery_threshold: u32,
     num_storage_nodes: u32,
-    multiplicity: u32,
+    max_multiplicity: u32,
     ck: KzgProverParam<E>,
     vk: KzgVerifierParam<E>,
-    multi_open_domain: Radix2EvaluationDomain<KzgPoint<E>>,
-
-    // TODO might be able to eliminate this field and instead use
-    // `EvaluationDomain::reindex_by_subdomain()` on `multi_open_domain`
-    // but that method consumes `other` and its doc is unclear.
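This patch deletes both cached domain fields; callers now rebuild domains on demand from the derived multiplicity (see the new `multi_open_domain` and `eval_domain` helpers later in the diff). A minimal sketch of that on-the-fly construction, assuming ark-poly and BN254's scalar field:

```rust
use ark_bn254::Fr;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};

// Derive a radix-2 FFT domain on demand instead of caching it in the
// struct. `new` rounds the size up to the next power of two and fails
// only if that exceeds the field's two-adicity.
fn eval_domain(domain_size: usize) -> Option<Radix2EvaluationDomain<Fr>> {
    Radix2EvaluationDomain::<Fr>::new(domain_size)
}

fn main() {
    let domain = eval_domain(6).unwrap();
    assert_eq!(domain.size(), 8); // 6 rounded up to a power of two
}
```

Rebuilding costs one `new` call per use, which is cheap next to the FFTs themselves, and it is what allows the multiplicity to vary per payload.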
- eval_domain: Radix2EvaluationDomain>, // tuple of // - reference to the SRS/ProverParam loaded to GPU @@ -131,15 +125,20 @@ where ) -> VidResult { // TODO intelligent choice of multiplicity // https://github.com/EspressoSystems/jellyfish/issues/534 - let multiplicity = 1; + let max_multiplicity = 1; - Self::with_multiplicity_internal(num_storage_nodes, recovery_threshold, multiplicity, srs) + Self::with_multiplicity_internal( + num_storage_nodes, + recovery_threshold, + max_multiplicity, + srs, + ) } pub(crate) fn with_multiplicity_internal( num_storage_nodes: u32, // n (code rate: r = k/n) recovery_threshold: u32, // k - multiplicity: u32, // batch m chunks, keep the rate r = (m*k)/(m*n) + max_multiplicity: u32, // batch m chunks, keep the rate r = (m*k)/(m*n) srs: impl Borrow>, ) -> VidResult { if num_storage_nodes < recovery_threshold { @@ -149,49 +148,29 @@ where ))); } - if !multiplicity.is_power_of_two() { + // TODO TEMPORARY: enforce power-of-2 + // https://github.com/EspressoSystems/jellyfish/issues/668 + if !recovery_threshold.is_power_of_two() { return Err(VidError::Argument(format!( - "multiplicity {multiplicity} should be a power of two" + "recovery_threshold {recovery_threshold} should be a power of two" ))); } - - // erasure code params - let chunk_size = multiplicity * recovery_threshold; // message length m - let code_word_size = multiplicity * num_storage_nodes; // code word length n - let poly_degree = chunk_size - 1; - - let (ck, vk) = UnivariateKzgPCS::trim_fft_size(srs, poly_degree as usize).map_err(vid)?; - let multi_open_domain = UnivariateKzgPCS::::multi_open_rou_eval_domain( - poly_degree as usize, - code_word_size as usize, - ) - .map_err(vid)?; - let eval_domain = Radix2EvaluationDomain::new(chunk_size as usize).ok_or_else(|| { - VidError::Internal(anyhow::anyhow!( - "fail to construct domain of size {}", - chunk_size - )) - })?; - - // TODO TEMPORARY: enforce power-of-2 chunk size - // Remove this restriction after we get KZG in eval form - // https://github.com/EspressoSystems/jellyfish/issues/339 - if chunk_size as usize != eval_domain.size() { + if !max_multiplicity.is_power_of_two() { return Err(VidError::Argument(format!( - "recovery_threshold {} currently unsupported, round to {} instead", - chunk_size, - eval_domain.size() + "max_multiplicity {max_multiplicity} should be a power of two" ))); } + let supported_degree = + usize::try_from(max_multiplicity * recovery_threshold - 1).map_err(vid)?; + let (ck, vk) = UnivariateKzgPCS::trim_fft_size(srs, supported_degree).map_err(vid)?; + Ok(Self { recovery_threshold, num_storage_nodes, - multiplicity, + max_multiplicity, ck, vk, - multi_open_domain, - eval_domain, srs_on_gpu_and_cuda_stream: None, _pd: Default::default(), }) @@ -215,7 +194,7 @@ where /// # Errors /// Return [`VidError::Argument`] if /// - `num_storage_nodes < recovery_threshold` - /// - TEMPORARY `recovery_threshold` is not a power of two [github issue](https://github.com/EspressoSystems/jellyfish/issues/339) + /// - TEMPORARY `recovery_threshold` is not a power of two [github issue](https://github.com/EspressoSystems/jellyfish/issues/668) pub fn new( num_storage_nodes: u32, recovery_threshold: u32, @@ -224,21 +203,28 @@ where Self::new_internal(num_storage_nodes, recovery_threshold, srs) } - /// Like [`Advz::new`] except with a `multiplicity` arg. + /// Like [`Advz::new`] except with a `max_multiplicity` arg. /// - /// `multiplicity` is an implementation-specific optimization arg. 
- /// Each storage node gets `multiplicity` evaluations per polynomial. + /// `max_multiplicity` is an implementation-specific optimization arg. + /// Each storage node gets up to `max_multiplicity` evaluations per + /// polynomial. The actual multiplicity used will be the smallest value m + /// such that payload can fit (payload_len <= m * recovery_threshold). /// /// # Errors /// In addition to [`Advz::new`], return [`VidError::Argument`] if - /// - TEMPORARY `multiplicity` is not a power of two [github issue](https://github.com/EspressoSystems/jellyfish/issues/339) + /// - TEMPORARY `max_multiplicity` is not a power of two [github issue](https://github.com/EspressoSystems/jellyfish/issues/668) pub fn with_multiplicity( num_storage_nodes: u32, recovery_threshold: u32, - multiplicity: u32, + max_multiplicity: u32, srs: impl Borrow>, ) -> VidResult { - Self::with_multiplicity_internal(num_storage_nodes, recovery_threshold, multiplicity, srs) + Self::with_multiplicity_internal( + num_storage_nodes, + recovery_threshold, + max_multiplicity, + srs, + ) } } @@ -262,13 +248,13 @@ where pub fn with_multiplicity( num_storage_nodes: u32, recovery_threshold: u32, - multiplicity: u32, + max_multiplicity: u32, srs: impl Borrow>, ) -> VidResult { let mut advz = Self::with_multiplicity_internal( num_storage_nodes, recovery_threshold, - multiplicity, + max_multiplicity, srs, )?; advz.init_gpu_srs()?; @@ -281,7 +267,7 @@ where self.ck.powers_of_g.len() - 1, ) .map_err(vid)?; - self.srs_on_gpu_and_cuda_stream = Some((srs_on_gpu, warmup_new_stream().unwrap())); + self.srs_on_gpu_and_cuda_stream = Some((srs_on_gpu, warmup_new_stream().map_err(vid)?)); Ok(()) } } @@ -307,7 +293,7 @@ where evals: Vec>, #[serde(with = "canonical")] - // aggretate_proofs.len() equals self.multiplicity + // aggregate_proofs.len() equals multiplicity // TODO further aggregate into a single KZG proof. 
aggregate_proofs: Vec>, @@ -376,12 +362,10 @@ where &mut self, polys: &[DensePolynomial], ) -> VidResult>> { - // let mut srs_on_gpu = self.srs_on_gpu_and_cuda_stream.as_mut().unwrap().0; - // let stream = &self.srs_on_gpu_and_cuda_stream.as_ref().unwrap().1; if polys.is_empty() { return Ok(vec![]); } - let (srs_on_gpu, stream) = self.srs_on_gpu_and_cuda_stream.as_mut().unwrap(); // safe by construction + let (srs_on_gpu, stream) = self.srs_on_gpu_and_cuda_stream.as_mut().map_err(vid)?; // safe by construction as GPUCommittable>::gpu_batch_commit_with_loaded_prover_param( srs_on_gpu, polys, stream, ) @@ -407,9 +391,10 @@ where B: AsRef<[u8]>, { let payload = payload.as_ref(); - let bytes_to_polys_time = start_timer!(|| "encode payload bytes into polynomials"); - let polys = self.bytes_to_polys(payload); - end_timer!(bytes_to_polys_time); + let payload_byte_len = payload.len().try_into().map_err(vid)?; + let multiplicity = self.min_multiplicity(payload_byte_len)?; + let chunk_size = multiplicity * self.recovery_threshold; + let polys = self.bytes_to_polys(payload)?; let poly_commits_time = start_timer!(|| "batch poly commit"); let poly_commits = >::kzg_batch_commit(self, &polys)?; @@ -423,82 +408,10 @@ where B: AsRef<[u8]>, { let payload = payload.as_ref(); - let payload_byte_len = payload.len().try_into().map_err(vid)?; - let disperse_time = start_timer!(|| format!( - "VID disperse {} payload bytes to {} nodes", - payload_byte_len, self.num_storage_nodes - )); - let _chunk_size = self.multiplicity * self.recovery_threshold; - let code_word_size = self.multiplicity * self.num_storage_nodes; - - // partition payload into polynomial coefficients - let bytes_to_polys_time = start_timer!(|| "encode payload bytes into polynomials"); - let polys = self.bytes_to_polys(payload); - end_timer!(bytes_to_polys_time); - - // evaluate polynomials - let all_storage_node_evals_timer = start_timer!(|| format!( - "compute all storage node evals for {} polynomials with {} coefficients", - polys.len(), - _chunk_size - )); - let all_storage_node_evals = self.evaluate_polys(&polys)?; - end_timer!(all_storage_node_evals_timer); - - // vector commitment to polynomial evaluations - let all_evals_commit_timer = - start_timer!(|| "compute merkle root of all storage node evals"); - let all_evals_commit = - KzgEvalsMerkleTree::::from_elems(None, &all_storage_node_evals).map_err(vid)?; - end_timer!(all_evals_commit_timer); - - let common_timer = start_timer!(|| format!("compute {} KZG commitments", polys.len())); - let common = Common { - poly_commits: >::kzg_batch_commit(self, &polys)?, - all_evals_digest: all_evals_commit.commitment().digest(), - payload_byte_len, - num_storage_nodes: self.num_storage_nodes, - multiplicity: self.multiplicity, - }; - end_timer!(common_timer); - - let commit = Self::derive_commit( - &common.poly_commits, - payload_byte_len, - self.num_storage_nodes, - )?; - let pseudorandom_scalar = Self::pseudorandom_scalar(&common, &commit)?; - - // Compute aggregate polynomial as a pseudorandom linear combo of polynomial via - // evaluation of the polynomial whose coefficients are polynomials and whose - // input point is the pseudorandom scalar. 
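Both the removed and the new disperse paths derive the scalar just computed above via a Fiat-Shamir-style hash-to-field. A standalone sketch of that pattern; the real `pseudorandom_scalar` hashes `commit` and `common` under a domain separator, which is elided here:

```rust
use ark_bn254::Fr;
use ark_ff::PrimeField;
use sha2::{Digest, Sha256};

// Hash a transcript, then reduce the 32-byte digest into the scalar
// field (interpreted as little-endian bytes mod the field order).
fn pseudorandom_scalar(transcript: &[u8]) -> Fr {
    let digest = Sha256::digest(transcript);
    Fr::from_le_bytes_mod_order(&digest)
}
```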
- let aggregate_poly = - polynomial_eval(polys.iter().map(PolynomialMultiplier), pseudorandom_scalar); - - let agg_proofs_timer = start_timer!(|| format!( - "compute aggregate proofs for {} storage nodes", - self.num_storage_nodes - )); - let aggregate_proofs = UnivariateKzgPCS::multi_open_rou_proofs( - &self.ck, - &aggregate_poly, - code_word_size as usize, - &self.multi_open_domain, - ) - .map_err(vid)?; - end_timer!(agg_proofs_timer); - - let assemblage_timer = start_timer!(|| "assemble shares for dispersal"); - let shares = - self.assemble_shares(all_storage_node_evals, aggregate_proofs, all_evals_commit)?; - end_timer!(assemblage_timer); - end_timer!(disperse_time); + let polys = self.bytes_to_polys(payload)?; + let poly_commits = >::kzg_batch_commit(self, &polys)?; - Ok(VidDisperse { - shares, - common, - commit, - }) + self.disperse_with_polys_and_commits(payload, polys, poly_commits) } fn verify_share( @@ -514,21 +427,22 @@ where common.num_storage_nodes, self.num_storage_nodes ))); } - if common.multiplicity != self.multiplicity { + let multiplicity = + self.min_multiplicity(common.payload_byte_len.try_into().map_err(vid)?)?; + if common.multiplicity != multiplicity { return Err(VidError::Argument(format!( - "common multiplicity {} differs from self {}", - common.multiplicity, self.multiplicity + "common multiplicity {} differs from derived min {}", + common.multiplicity, multiplicity ))); } - let multiplicity: usize = common.multiplicity.try_into().map_err(vid)?; - if share.evals.len() / multiplicity != common.poly_commits.len() { + if share.evals.len() / multiplicity as usize != common.poly_commits.len() { return Err(VidError::Argument(format!( "number of share evals / multiplicity {}/{} differs from number of common polynomial commitments {}", share.evals.len(), multiplicity, common.poly_commits.len() ))); } - if share.eval_proofs.len() != multiplicity { + if share.eval_proofs.len() != multiplicity as usize { return Err(VidError::Argument(format!( "number of eval_proofs {} differs from common multiplicity {}", share.eval_proofs.len(), @@ -543,10 +457,10 @@ where } // verify eval proofs - for i in 0..self.multiplicity { + for i in 0..multiplicity { if KzgEvalsMerkleTree::::verify( common.all_evals_digest, - &KzgEvalsMerkleTreeIndex::::from((share.index * self.multiplicity) + i), + &KzgEvalsMerkleTreeIndex::::from((share.index * multiplicity) + i), &share.eval_proofs[i as usize], ) .map_err(vid)? @@ -577,8 +491,9 @@ where // // some boilerplate needed to accommodate builds without `parallel` // feature. 
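The "boilerplate" mentioned in the comment above exists because `parallelizable_slice_iter` (from jf-utils) yields different iterator types depending on the `parallel` feature. Its usual shape is roughly the following; this is an assumed sketch, not the crate's exact code:

```rust
#[cfg(feature = "parallel")]
use rayon::prelude::*;

// With `parallel` enabled: a rayon parallel iterator over the slice.
#[cfg(feature = "parallel")]
pub fn parallelizable_slice_iter<T: Sync>(data: &[T]) -> rayon::slice::Iter<'_, T> {
    data.par_iter()
}

// Without `parallel`: a plain sequential iterator with the same usage.
#[cfg(not(feature = "parallel"))]
pub fn parallelizable_slice_iter<T>(data: &[T]) -> core::slice::Iter<'_, T> {
    data.iter()
}
```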
- let multiplicities = Vec::from_iter((0..self.multiplicity as usize)); + let multiplicities = Vec::from_iter((0..multiplicity as usize)); let polys_len = common.poly_commits.len(); + let multi_open_domain = self.multi_open_domain(multiplicity)?; let verification_iter = parallelizable_slice_iter(&multiplicities).map(|i| { let range = i * polys_len..(i + 1) * polys_len; let aggregate_eval = polynomial_eval( @@ -599,9 +514,7 @@ where Ok(UnivariateKzgPCS::verify( &self.vk, &aggregate_poly_commit, - &self - .multi_open_domain - .element((share.index as usize * multiplicity) + i), + &multi_open_domain.element((share.index * multiplicity) as usize + i), &aggregate_eval, &share.aggregate_proofs[*i], ) @@ -625,6 +538,7 @@ where } fn recover_payload(&self, shares: &[Self::Share], common: &Self::Common) -> VidResult> { + // check args if shares.len() < self.recovery_threshold as usize { return Err(VidError::Argument(format!( "not enough shares {}, expected at least {}", @@ -639,7 +553,7 @@ where ))); } - // all shares must have equal evals len + // check args: all shares must have equal evals len let num_evals = shares .first() .ok_or_else(|| VidError::Argument("shares is empty".into()))? @@ -658,26 +572,29 @@ where share.evals.len() ))); } - if num_evals != self.multiplicity as usize * common.poly_commits.len() { + if num_evals != common.multiplicity as usize * common.poly_commits.len() { return Err(VidError::Argument(format!( "num_evals should be (multiplicity * poly_commits): {} but is instead: {}", - self.multiplicity as usize * common.poly_commits.len(), + common.multiplicity as usize * common.poly_commits.len(), num_evals, ))); } - let chunk_size = self.multiplicity * self.recovery_threshold; - let num_polys = num_evals / self.multiplicity as usize; - let elems_capacity = num_polys * chunk_size as usize; - let mut elems = Vec::with_capacity(elems_capacity); + // convenience quantities + let chunk_size = + usize::try_from(common.multiplicity * self.recovery_threshold).map_err(vid)?; + let num_polys = common.poly_commits.len(); + let elems_capacity = num_polys * chunk_size; + let fft_domain = Self::eval_domain(chunk_size)?; + let mut elems = Vec::with_capacity(elems_capacity); let mut evals = Vec::with_capacity(num_evals); for p in 0..num_polys { for share in shares { // extract all evaluations for polynomial p from the share - for m in 0..self.multiplicity as usize { + for m in 0..common.multiplicity as usize { evals.push(( - (share.index * self.multiplicity) as usize + m, + (share.index * common.multiplicity) as usize + m, share.evals[(m * num_polys) + p], )) } @@ -685,14 +602,14 @@ where let mut coeffs = reed_solomon_erasure_decode_rou( mem::take(&mut evals), chunk_size as usize, - &self.multi_open_domain, + &self.multi_open_domain(common.multiplicity)?, ) .map_err(vid)?; // TODO TEMPORARY: use FFT to encode polynomials in eval form // Remove these FFTs after we get KZG in eval form // https://github.com/EspressoSystems/jellyfish/issues/339 - self.eval_domain.fft_in_place(&mut coeffs); + fft_domain.fft_in_place(&mut coeffs); elems.append(&mut coeffs); } @@ -738,48 +655,142 @@ where SrsRef: Sync, AdvzInternal: MaybeGPU, { - fn evaluate_polys( + fn disperse_with_polys_and_commits( &self, - polys: &[DensePolynomial<::ScalarField>], - ) -> Result::ScalarField>>, VidError> - where - E: Pairing, - H: HasherDigest, - { - let code_word_size = (self.num_storage_nodes * self.multiplicity) as usize; - let mut all_storage_node_evals = vec![Vec::with_capacity(polys.len()); code_word_size]; - // this is to 
avoid `SrsRef` not implementing `Sync` problem, - // instead of sending entire `self` cross thread, we only send a ref which is - // Sync - let multi_open_domain_ref = &self.multi_open_domain; - - let all_poly_evals = parallelizable_slice_iter(polys) - .map(|poly| { - UnivariateKzgPCS::::multi_open_rou_evals( - poly, - code_word_size, - multi_open_domain_ref, - ) - .map_err(vid) - }) - .collect::, VidError>>()?; + payload: &[u8], + polys: Vec::ScalarField>>, + poly_commits: Vec>, + ) -> VidResult> { + let payload_byte_len = payload.len().try_into().map_err(vid)?; + let disperse_time = start_timer!(|| format!( + "VID disperse {} payload bytes to {} nodes", + payload_byte_len, self.num_storage_nodes + )); + let multiplicity = self.min_multiplicity(payload.len())?; + let code_word_size = usize::try_from(multiplicity * self.num_storage_nodes).map_err(vid)?; + let multi_open_domain = self.multi_open_domain(multiplicity)?; - for poly_evals in all_poly_evals { - for (storage_node_evals, poly_eval) in all_storage_node_evals - .iter_mut() - .zip(poly_evals.into_iter()) - { - storage_node_evals.push(poly_eval); + // evaluate polynomials + let all_storage_node_evals_timer = start_timer!(|| format!( + "compute all storage node evals for {} polynomials with {} coefficients", + polys.len(), + multiplicity * self.recovery_threshold + )); + let all_storage_node_evals = { + let mut all_storage_node_evals = vec![Vec::with_capacity(polys.len()); code_word_size]; + let all_poly_evals = parallelizable_slice_iter(&polys) + .map(|poly| { + UnivariateKzgPCS::::multi_open_rou_evals( + poly, + code_word_size, + &multi_open_domain, + ) + .map_err(vid) + }) + .collect::, VidError>>()?; + + for poly_evals in all_poly_evals { + for (storage_node_evals, poly_eval) in all_storage_node_evals + .iter_mut() + .zip(poly_evals.into_iter()) + { + storage_node_evals.push(poly_eval); + } } - } - // sanity checks - assert_eq!(all_storage_node_evals.len(), code_word_size); - for storage_node_evals in all_storage_node_evals.iter() { - assert_eq!(storage_node_evals.len(), polys.len()); - } + // sanity checks + assert_eq!(all_storage_node_evals.len(), code_word_size); + for storage_node_evals in all_storage_node_evals.iter() { + assert_eq!(storage_node_evals.len(), polys.len()); + } + + all_storage_node_evals + }; + end_timer!(all_storage_node_evals_timer); + + // vector commitment to polynomial evaluations + let all_evals_commit_timer = + start_timer!(|| "compute merkle root of all storage node evals"); + let all_evals_commit = + KzgEvalsMerkleTree::::from_elems(None, &all_storage_node_evals).map_err(vid)?; + end_timer!(all_evals_commit_timer); + + let common = Common { + poly_commits, + all_evals_digest: all_evals_commit.commitment().digest(), + payload_byte_len, + num_storage_nodes: self.num_storage_nodes, + multiplicity, + }; + + let commit = Self::derive_commit( + &common.poly_commits, + payload_byte_len, + self.num_storage_nodes, + )?; + let pseudorandom_scalar = Self::pseudorandom_scalar(&common, &commit)?; + + // Compute aggregate polynomial as a pseudorandom linear combo of polynomial via + // evaluation of the polynomial whose coefficients are polynomials and whose + // input point is the pseudorandom scalar. 
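The comment above treats the payload polynomials as coefficients of an outer polynomial evaluated at the pseudorandom scalar r, i.e. the combination sum_i r^i * p_i. A standalone sketch of that computation as a Horner fold; the crate does this via `polynomial_eval` over `PolynomialMultiplier` wrappers:

```rust
use ark_bn254::Fr;
use ark_ff::Zero;
use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};

// Horner evaluation of the "polynomial whose coefficients are
// polynomials" at r: returns p_0 + r*p_1 + r^2*p_2 + ...
fn aggregate_poly(polys: &[DensePolynomial<Fr>], r: Fr) -> DensePolynomial<Fr> {
    polys.iter().rev().fold(DensePolynomial::zero(), |acc, p| {
        // acc <- acc * r + p, coefficient-wise
        let mut coeffs: Vec<Fr> = acc.coeffs.iter().map(|c| *c * r).collect();
        for (i, c) in p.coeffs.iter().enumerate() {
            if i < coeffs.len() {
                coeffs[i] += c;
            } else {
                coeffs.push(*c);
            }
        }
        DensePolynomial::from_coefficients_vec(coeffs)
    })
}
```

A single KZG proof for this aggregate then vouches for all polynomials at once, which is why only `multiplicity` proofs per share are needed rather than one per polynomial.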
+ let aggregate_poly = + polynomial_eval(polys.iter().map(PolynomialMultiplier), pseudorandom_scalar); + + let agg_proofs_timer = start_timer!(|| format!( + "compute aggregate proofs for {} storage nodes", + self.num_storage_nodes + )); + let aggregate_proofs = UnivariateKzgPCS::multi_open_rou_proofs( + &self.ck, + &aggregate_poly, + code_word_size as usize, + &multi_open_domain, + ) + .map_err(vid)?; + end_timer!(agg_proofs_timer); - Ok(all_storage_node_evals) + let assemblage_timer = start_timer!(|| "assemble shares for dispersal"); + let shares: Vec<_> = { + // compute share data + let share_data = all_storage_node_evals + .into_iter() + .zip(aggregate_proofs) + .enumerate() + .map(|(i, (eval, proof))| { + let eval_proof = all_evals_commit + .lookup(KzgEvalsMerkleTreeIndex::::from(i as u64)) + .expect_ok() + .map_err(vid)? + .1; + Ok((eval, proof, eval_proof)) + }) + .collect::, VidError>>()?; + + // split share data into chunks of size multiplicity + share_data + .into_iter() + .chunks(multiplicity as usize) + .into_iter() + .enumerate() + .map(|(index, chunk)| { + let (evals, proofs, eval_proofs): (Vec<_>, _, _) = chunk.multiunzip(); + Share { + index: index as u32, + evals: evals.into_iter().flatten().collect::>(), + aggregate_proofs: proofs, + eval_proofs, + } + }) + .collect() + }; + end_timer!(assemblage_timer); + end_timer!(disperse_time); + + Ok(VidDisperse { + shares, + common, + commit, + }) } fn pseudorandom_scalar( @@ -809,33 +820,46 @@ where Ok(PrimeField::from_le_bytes_mod_order(&hasher.finalize())) } - fn bytes_to_polys(&self, payload: &[u8]) -> Vec::ScalarField>> + /// Partition payload into polynomial coefficients + fn bytes_to_polys( + &self, + payload: &[u8], + ) -> VidResult::ScalarField>>> where E: Pairing, { - let chunk_size = (self.recovery_threshold * self.multiplicity) as usize; let elem_bytes_len = bytes_to_field::elem_byte_capacity::<::ScalarField>(); - let eval_domain_ref = &self.eval_domain; + let domain_size = + usize::try_from(self.min_multiplicity(payload.len())? * self.recovery_threshold) + .map_err(vid)?; - parallelizable_chunks(payload, chunk_size * elem_bytes_len) + let bytes_to_polys_time = start_timer!(|| "encode payload bytes into polynomials"); + let result = parallelizable_chunks(payload, domain_size * elem_bytes_len) .map(|chunk| { - Self::polynomial_internal( - eval_domain_ref, - chunk_size, - bytes_to_field::<_, KzgEval>(chunk), - ) + Self::interpolate_polynomial(bytes_to_field::<_, KzgEval>(chunk), domain_size) }) - .collect() + .collect::>>(); + end_timer!(bytes_to_polys_time); + result } - // This is an associated function, not a method, doesn't take in `self`, thus - // more friendly to cross-thread `Sync`, especially when on of the generic - // param of `Self` didn't implement `Sync` - fn polynomial_internal( - domain_ref: &Radix2EvaluationDomain>, - chunk_size: usize, - coeffs: I, - ) -> KzgPolynomial + /// Consume `evals` and return a polynomial that interpolates `evals` on a + /// evaluation domain of size `domain_size`. + /// + /// Return an error if the length of `evals` exceeds `domain_size`. + /// + /// The degree-plus-1 of the returned polynomial is always a power of two + /// because: + /// + /// - We use FFT to interpolate, so `domain_size` is rounded up to the next + /// power of two. + /// - [`KzgPolynomial`] implementation is stored in coefficient form. + /// + /// See https://github.com/EspressoSystems/jellyfish/issues/339 + /// + /// Why is this method an associated function of `Self`? 
Because we want to + /// use a generic parameter of `Self`. + fn interpolate_polynomial(evals: I, domain_size: usize) -> VidResult> where I: Iterator, I::Item: Borrow>, @@ -843,31 +867,60 @@ where // TODO TEMPORARY: use FFT to encode polynomials in eval form // Remove these FFTs after we get KZG in eval form // https://github.com/EspressoSystems/jellyfish/issues/339 - let mut coeffs_vec: Vec<_> = coeffs.map(|c| *c.borrow()).collect(); - let pre_fft_len = coeffs_vec.len(); - EvaluationDomain::ifft_in_place(domain_ref, &mut coeffs_vec); - - // sanity check: the fft did not resize coeffs. - // If pre_fft_len != self.recovery_threshold * self.multiplicity - // then we were not given the correct number of coeffs. In that case - // coeffs.len() could be anything, so there's nothing to sanity check. - if pre_fft_len == chunk_size { - assert_eq!(coeffs_vec.len(), pre_fft_len); + let mut evals_vec: Vec<_> = evals.map(|c| *c.borrow()).collect(); + let pre_fft_len = evals_vec.len(); + if pre_fft_len > domain_size { + return Err(VidError::Internal(anyhow::anyhow!( + "number of evals {} exceeds domain_size {}", + pre_fft_len, + domain_size + ))); } + let domain = Self::eval_domain(domain_size)?; - DenseUVPolynomial::from_coefficients_vec(coeffs_vec) + domain.ifft_in_place(&mut evals_vec); + + // sanity: the fft did not resize evals. If pre_fft_len < domain_size + // then we were given too few evals, in which case there's nothing to + // sanity check. + if pre_fft_len == domain_size && pre_fft_len != evals_vec.len() { + return Err(VidError::Internal(anyhow::anyhow!( + "unexpected output resize from {pre_fft_len} to {}", + evals_vec.len() + ))); + } + + Ok(DenseUVPolynomial::from_coefficients_vec(evals_vec)) } - fn polynomial(&self, coeffs: I) -> KzgPolynomial - where - I: Iterator, - I::Item: Borrow>, - { - Self::polynomial_internal( - &self.eval_domain, - (self.recovery_threshold * self.multiplicity) as usize, - coeffs, - ) + fn min_multiplicity(&self, payload_byte_len: usize) -> VidResult { + let elem_bytes_len = bytes_to_field::elem_byte_capacity::<::ScalarField>(); + let elems: u32 = payload_byte_len + .div_ceil(elem_bytes_len) + .try_into() + .map_err(vid)?; + if self.recovery_threshold * self.max_multiplicity < elems { + // payload is large. no change in multiplicity needed. + return Ok(self.max_multiplicity); + } + + // payload is small: choose the smallest `m` such that `0 < m < + // multiplicity` and the entire payload fits into `m * + // recovery_threshold` elements. + let m = elems.div_ceil(self.recovery_threshold.max(1)).max(1); + + // TODO TEMPORARY: enforce power-of-2 + // https://github.com/EspressoSystems/jellyfish/issues/668 + // + // Round up to the nearest power of 2. + // + // After the above issue is fixed: delete the following code and return + // `m` from above. + if m <= 1 { + Ok(1) + } else { + Ok(1 << ((m - 1).ilog2() + 1)) + } } /// Derive a commitment from whatever data is needed. @@ -903,53 +956,24 @@ where Ok(hasher.finalize().into()) } - /// Assemble shares from evaluations and proofs. - /// - /// Each share contains (for multiplicity m): - /// 1. (m * num_poly) evaluations. - /// 2. a collection of m KZG proofs. TODO KZG aggregation https://github.com/EspressoSystems/jellyfish/issues/356 - /// 3. m merkle tree membership proofs. 
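The multiplicity rule in `min_multiplicity` above reduces to a few lines. A simplified standalone version, with error handling elided and 31 bytes per element assumed (matching BN254's scalar field):

```rust
// Smallest power-of-two multiplicity m such that the payload fits into
// m * recovery_threshold field elements, capped at max_multiplicity.
fn min_multiplicity(payload_byte_len: usize, recovery_threshold: u32, max_multiplicity: u32) -> u32 {
    let elem_bytes_len = 31; // bytes_to_field::elem_byte_capacity::<Fr>() for BN254
    let elems = payload_byte_len.div_ceil(elem_bytes_len) as u32;
    if recovery_threshold * max_multiplicity < elems {
        // payload is large: use the maximum multiplicity
        return max_multiplicity;
    }
    // Smallest m with elems <= m * recovery_threshold, rounded up to a
    // power of two (issue #668). Note `1 << ((m - 1).ilog2() + 1)` in
    // the patch is exactly `next_power_of_two()` for m >= 2.
    elems
        .div_ceil(recovery_threshold.max(1))
        .max(1)
        .next_power_of_two()
}

// e.g. recovery_threshold = 4, max_multiplicity = 32:
// 0 bytes -> m = 1; 100 bytes (4 elems) -> m = 1; 500 bytes (17 elems) -> m = 8.
```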
- fn assemble_shares( + fn multi_open_domain( &self, - all_storage_node_evals: Vec::ScalarField>>, - aggregate_proofs: Vec>, - all_evals_commit: KzgEvalsMerkleTree, - ) -> Result>, VidError> - where - E: Pairing, - H: HasherDigest, - { - // compute share data - let share_data = all_storage_node_evals - .into_iter() - .zip(aggregate_proofs) - .enumerate() - .map(|(i, (eval, proof))| { - let eval_proof = all_evals_commit - .lookup(KzgEvalsMerkleTreeIndex::::from(i as u64)) - .expect_ok() - .map_err(vid)? - .1; - Ok((eval, proof, eval_proof)) - }) - .collect::, VidError>>()?; + multiplicity: u32, + ) -> VidResult::ScalarField>> { + let chunk_size = usize::try_from(multiplicity * self.recovery_threshold).map_err(vid)?; + let code_word_size = usize::try_from(multiplicity * self.num_storage_nodes).map_err(vid)?; + UnivariateKzgPCS::::multi_open_rou_eval_domain(chunk_size - 1, code_word_size) + .map_err(vid) + } - // split share data into chunks of size multiplicity - Ok(share_data - .into_iter() - .chunks(self.multiplicity as usize) - .into_iter() - .enumerate() - .map(|(index, chunk)| { - let (evals, proofs, eval_proofs): (Vec<_>, _, _) = chunk.multiunzip(); - Share { - index: index as u32, - evals: evals.into_iter().flatten().collect::>(), - aggregate_proofs: proofs, - eval_proofs, - } - }) - .collect()) + fn eval_domain( + domain_size: usize, + ) -> VidResult::ScalarField>> { + Radix2EvaluationDomain::>::new(domain_size).ok_or_else(|| { + VidError::Internal(anyhow::anyhow!( + "fail to construct domain of size {domain_size}" + )) + }) } } diff --git a/vid/src/advz/payload_prover.rs b/vid/src/advz/payload_prover.rs index 93d1cacb8..46baae595 100644 --- a/vid/src/advz/payload_prover.rs +++ b/vid/src/advz/payload_prover.rs @@ -25,7 +25,7 @@ use crate::{ }; use anyhow::anyhow; use ark_ec::pairing::Pairing; -use ark_poly::EvaluationDomain; +use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{format, ops::Range}; use itertools::Itertools; @@ -81,26 +81,38 @@ where check_range_nonempty_and_in_bounds(payload.len(), &range)?; // index conversion + let multiplicity = self.min_multiplicity(payload.len())?; let range_elem = self.range_byte_to_elem(&range); - let range_poly = self.range_elem_to_poly(&range_elem); + let range_poly = self.range_elem_to_poly(&range_elem, multiplicity); let range_elem_byte = self.range_elem_to_byte_clamped(&range_elem, payload.len()); - let range_poly_byte = self.range_poly_to_byte_clamped(&range_poly, payload.len()); - let offset_elem = self.offset_poly_to_elem(range_poly.start, range_elem.start); + let range_poly_byte = + self.range_poly_to_byte_clamped(&range_poly, payload.len(), multiplicity); + let offset_elem = + self.offset_poly_to_elem(range_poly.start, range_elem.start, multiplicity); let final_points_range_end = - self.final_poly_points_range_end(range_elem.len(), offset_elem); + self.final_poly_points_range_end(range_elem.len(), offset_elem, multiplicity); // prepare list of input points - // perf: we might not need all these points - let points: Vec<_> = self.eval_domain.elements().collect(); + // + // perf: if payload is small enough to fit into a single polynomial then + // we don't need all the points in this domain. + let points: Vec<_> = Self::eval_domain( + usize::try_from(self.recovery_threshold * multiplicity).map_err(vid)?, + )? 
+ .elements() + .collect(); let elems_iter = bytes_to_field::<_, KzgEval>(&payload[range_poly_byte]); let mut proofs = Vec::with_capacity(range_poly.len() * points.len()); for (i, evals_iter) in elems_iter - .chunks(self.recovery_threshold as usize) + .chunks((self.recovery_threshold * multiplicity) as usize) .into_iter() .enumerate() { - let poly = self.polynomial(evals_iter); + let poly = Self::interpolate_polynomial( + evals_iter, + (self.recovery_threshold * multiplicity) as usize, + )?; let points_range = Range { // first polynomial? skip to the start of the proof range start: if i == 0 { offset_elem } else { 0 }, @@ -152,14 +164,24 @@ where // index conversion let range_elem = self.range_byte_to_elem(&stmt.range); - let range_poly = self.range_elem_to_poly(&range_elem); - let offset_elem = self.offset_poly_to_elem(range_poly.start, range_elem.start); - let final_points_range_end = - self.final_poly_points_range_end(range_elem.len(), offset_elem); + let range_poly = self.range_elem_to_poly(&range_elem, stmt.common.multiplicity); + let offset_elem = + self.offset_poly_to_elem(range_poly.start, range_elem.start, stmt.common.multiplicity); + let final_points_range_end = self.final_poly_points_range_end( + range_elem.len(), + offset_elem, + stmt.common.multiplicity, + ); // prepare list of input points - // perf: we might not need all these points - let points: Vec<_> = self.eval_domain.elements().collect(); + // + // perf: if payload is small enough to fit into a single polynomial then + // we don't need all the points in this domain. + let points: Vec<_> = Self::eval_domain( + usize::try_from(self.recovery_threshold * stmt.common.multiplicity).map_err(vid)?, + )? + .elements() + .collect(); // verify proof let mut cur_proof_index = 0; @@ -218,11 +240,14 @@ where check_range_nonempty_and_in_bounds(payload.len(), &range)?; // index conversion + let multiplicity = self.min_multiplicity(payload.len())?; let range_elem = self.range_byte_to_elem(&range); - let range_poly = self.range_elem_to_poly(&range_elem); + let range_poly = self.range_elem_to_poly(&range_elem, multiplicity); let range_elem_byte = self.range_elem_to_byte_clamped(&range_elem, payload.len()); - let range_poly_byte = self.range_poly_to_byte_clamped(&range_poly, payload.len()); - let offset_elem = self.offset_poly_to_elem(range_poly.start, range_elem.start); + let range_poly_byte = + self.range_poly_to_byte_clamped(&range_poly, payload.len(), multiplicity); + let offset_elem = + self.offset_poly_to_elem(range_poly.start, range_elem.start, multiplicity); // compute the prefix and suffix elems let mut elems_iter = bytes_to_field::<_, KzgEval>(payload[range_poly_byte].iter()); @@ -245,7 +270,7 @@ where Self::check_stmt_consistency(&stmt)?; // index conversion - let range_poly = self.range_byte_to_poly(&stmt.range); + let range_poly = self.range_byte_to_poly(&stmt.range, stmt.common.multiplicity); // rebuild the needed payload elements from statement and proof let elems_iter = proof @@ -260,14 +285,16 @@ where .chain(proof.suffix_bytes.iter()), )) .chain(proof.suffix_elems.iter().cloned()); - // rebuild the poly commits, check against `common` for (commit_index, evals_iter) in range_poly.into_iter().zip( elems_iter - .chunks(self.recovery_threshold as usize) + .chunks((self.recovery_threshold * stmt.common.multiplicity) as usize) .into_iter(), ) { - let poly = self.polynomial(evals_iter); + let poly = Self::interpolate_polynomial( + evals_iter, + (stmt.common.multiplicity * self.recovery_threshold) as usize, + )?; let poly_commit 
= UnivariateKzgPCS::commit(&self.ck, &poly).map_err(vid)?; if poly_commit != stmt.common.poly_commits[commit_index] { return Ok(Err(())); @@ -295,34 +322,49 @@ where ..result } } - fn range_elem_to_poly(&self, range: &Range) -> Range { - range_coarsen(range, self.recovery_threshold as usize) + fn range_elem_to_poly(&self, range: &Range, multiplicity: u32) -> Range { + range_coarsen(range, (self.recovery_threshold * multiplicity) as usize) } - fn range_byte_to_poly(&self, range: &Range) -> Range { + fn range_byte_to_poly(&self, range: &Range, multiplicity: u32) -> Range { range_coarsen( range, - self.recovery_threshold as usize * elem_byte_capacity::>(), + (self.recovery_threshold * multiplicity) as usize * elem_byte_capacity::>(), ) } - fn range_poly_to_byte_clamped(&self, range: &Range, len: usize) -> Range { + fn range_poly_to_byte_clamped( + &self, + range: &Range, + len: usize, + multiplicity: u32, + ) -> Range { let result = range_refine( range, - self.recovery_threshold as usize * elem_byte_capacity::>(), + (self.recovery_threshold * multiplicity) as usize * elem_byte_capacity::>(), ); Range { end: ark_std::cmp::min(result.end, len), ..result } } - fn offset_poly_to_elem(&self, range_poly_start: usize, range_elem_start: usize) -> usize { + fn offset_poly_to_elem( + &self, + range_poly_start: usize, + range_elem_start: usize, + multiplicity: u32, + ) -> usize { let start_poly_byte = index_refine( range_poly_start, - self.recovery_threshold as usize * elem_byte_capacity::>(), + (self.recovery_threshold * multiplicity) as usize * elem_byte_capacity::>(), ); range_elem_start - index_coarsen(start_poly_byte, elem_byte_capacity::>()) } - fn final_poly_points_range_end(&self, range_elem_len: usize, offset_elem: usize) -> usize { - (range_elem_len + offset_elem - 1) % self.recovery_threshold as usize + 1 + fn final_poly_points_range_end( + &self, + range_elem_len: usize, + offset_elem: usize, + multiplicity: u32, + ) -> usize { + (range_elem_len + offset_elem - 1) % (self.recovery_threshold * multiplicity) as usize + 1 } fn check_stmt_consistency(stmt: &Statement) -> VidResult<()> { @@ -403,17 +445,24 @@ mod tests { H: HasherDigest, { // play with these items - let (recovery_threshold, num_storage_nodes) = (4, 6); + let (recovery_threshold, num_storage_nodes, max_multiplicity) = (4, 6, 2); let num_polys = 3; let num_random_cases = 20; // more items as a function of the above - let payload_elems_len = num_polys * recovery_threshold as usize; + let poly_elems_len = recovery_threshold as usize * max_multiplicity as usize; + let payload_elems_len = num_polys * poly_elems_len; + let poly_bytes_len = poly_elems_len * elem_byte_capacity::(); let payload_bytes_base_len = payload_elems_len * elem_byte_capacity::(); - let poly_bytes_len = recovery_threshold as usize * elem_byte_capacity::(); let mut rng = jf_utils::test_rng(); let srs = init_srs(payload_elems_len, &mut rng); - let mut advz = Advz::::new(num_storage_nodes, recovery_threshold, srs).unwrap(); + let mut advz = Advz::::with_multiplicity( + num_storage_nodes, + recovery_threshold, + max_multiplicity, + srs, + ) + .unwrap(); // TEST: different payload byte lengths let payload_byte_len_noise_cases = vec![0, poly_bytes_len / 2, poly_bytes_len - 1]; @@ -444,9 +493,15 @@ mod tests { }; let all_cases = [(edge_cases, "edge"), (random_cases, "rand")]; + // at least one test case should have nontrivial multiplicity + let mut nontrivial_multiplicity = false; + for payload_len_case in payload_len_cases { let payload = 
init_random_payload(payload_len_case, &mut rng); let d = advz.disperse(&payload).unwrap(); + if d.common.multiplicity > 1 { + nontrivial_multiplicity = true; + } println!("payload byte len case: {}", payload.len()); for cases in all_cases.iter() { @@ -513,6 +568,11 @@ mod tests { } } + assert!( + nontrivial_multiplicity, + "at least one payload size should use multiplicity > 1" + ); + fn make_edge_cases(min: usize, max: usize) -> Vec> { vec![ Range { diff --git a/vid/src/advz/precomputable.rs b/vid/src/advz/precomputable.rs index 046367408..777bcf50a 100644 --- a/vid/src/advz/precomputable.rs +++ b/vid/src/advz/precomputable.rs @@ -39,7 +39,8 @@ where B: AsRef<[u8]>, { let payload = payload.as_ref(); - let polys = self.bytes_to_polys(payload); + let multiplicity = self.min_multiplicity(payload.len()); + let polys = self.bytes_to_polys(payload)?; let poly_commits: Vec> = UnivariateKzgPCS::batch_commit(&self.ck, &polys).map_err(vid)?; Ok(( @@ -57,88 +58,10 @@ where B: AsRef<[u8]>, { let payload = payload.as_ref(); - let payload_byte_len = payload.len().try_into().map_err(vid)?; - let disperse_time = start_timer!(|| ark_std::format!( - "(PRECOMPUTE): VID disperse {} payload bytes to {} nodes", - payload_byte_len, - self.num_storage_nodes - )); - let _chunk_size = self.multiplicity * self.recovery_threshold; - let code_word_size = self.multiplicity * self.num_storage_nodes; + let polys = self.bytes_to_polys(payload)?; + let poly_commits = data.poly_commits.clone(); - // partition payload into polynomial coefficients - // and count `elems_len` for later - let bytes_to_polys_time = start_timer!(|| "encode payload bytes into polynomials"); - let polys = self.bytes_to_polys(payload); - end_timer!(bytes_to_polys_time); - - // evaluate polynomials - let all_storage_node_evals_timer = start_timer!(|| ark_std::format!( - "compute all storage node evals for {} polynomials with {} coefficients", - polys.len(), - _chunk_size - )); - let all_storage_node_evals = self.evaluate_polys(&polys)?; - end_timer!(all_storage_node_evals_timer); - - // vector commitment to polynomial evaluations - // TODO why do I need to compute the height of the merkle tree? - let all_evals_commit_timer = - start_timer!(|| "compute merkle root of all storage node evals"); - let all_evals_commit = - KzgEvalsMerkleTree::::from_elems(None, &all_storage_node_evals).map_err(vid)?; - end_timer!(all_evals_commit_timer); - - let common_timer = start_timer!(|| ark_std::format!( - "(PRECOMPUTE): compute {} KZG commitments", - polys.len() - )); - let common = Common { - poly_commits: data.poly_commits.clone(), - all_evals_digest: all_evals_commit.commitment().digest(), - payload_byte_len, - num_storage_nodes: self.num_storage_nodes, - multiplicity: self.multiplicity, - }; - end_timer!(common_timer); - - let commit = Self::derive_commit( - &common.poly_commits, - payload_byte_len, - self.num_storage_nodes, - )?; - let pseudorandom_scalar = Self::pseudorandom_scalar(&common, &commit)?; - - // Compute aggregate polynomial as a pseudorandom linear combo of polynomial via - // evaluation of the polynomial whose coefficients are polynomials and whose - // input point is the pseudorandom scalar. 
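The payload-prover changes above thread `multiplicity` through a family of byte/element/polynomial index conversions (`range_coarsen`, `range_refine`, `index_coarsen`, `index_refine`). Their bodies are not shown in this diff; the sketch below infers them from the call sites, so treat it as assumed rather than authoritative:

```rust
use core::ops::Range;

// Map a fine-grained index to a coarser unit (e.g. byte -> element).
fn index_coarsen(index: usize, denominator: usize) -> usize {
    index / denominator
}

// Map a coarse index to the start of its fine-grained span.
fn index_refine(index: usize, multiplier: usize) -> usize {
    index * multiplier
}

// Coarsen a (nonempty) range: round start down and end up, so the
// coarse range covers the fine one.
fn range_coarsen(range: &Range<usize>, denominator: usize) -> Range<usize> {
    index_coarsen(range.start, denominator)..index_coarsen(range.end - 1, denominator) + 1
}

// Refine a range back to fine-grained units; callers clamp the end.
fn range_refine(range: &Range<usize>, multiplier: usize) -> Range<usize> {
    index_refine(range.start, multiplier)..index_refine(range.end, multiplier)
}
```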
- let aggregate_poly = - polynomial_eval(polys.iter().map(PolynomialMultiplier), pseudorandom_scalar); - - let agg_proofs_timer = start_timer!(|| ark_std::format!( - "compute aggregate proofs for {} storage nodes", - self.num_storage_nodes - )); - let aggregate_proofs = UnivariateKzgPCS::multi_open_rou_proofs( - &self.ck, - &aggregate_poly, - code_word_size as usize, - &self.multi_open_domain, - ) - .map_err(vid)?; - end_timer!(agg_proofs_timer); - - let assemblage_timer = start_timer!(|| "assemble shares for dispersal"); - let shares = - self.assemble_shares(all_storage_node_evals, aggregate_proofs, all_evals_commit)?; - end_timer!(assemblage_timer); - end_timer!(disperse_time); - - Ok(VidDisperse { - shares, - common, - commit, - }) + self.disperse_with_polys_and_commits(payload, polys, poly_commits) } } diff --git a/vid/src/advz/test.rs b/vid/src/advz/test.rs index d2aab1843..48a8a7c8b 100644 --- a/vid/src/advz/test.rs +++ b/vid/src/advz/test.rs @@ -1,5 +1,4 @@ use super::{VidError::Argument, *}; -use ark_bls12_381::Bls12_381; use ark_bn254::Bn254; use ark_std::{ rand::{CryptoRng, RngCore}, @@ -136,7 +135,7 @@ fn sad_path_verify_share_corrupt_commit() { // 1 corrupt commit, poly_commit let common_1_poly_corruption = { let mut corrupted = common.clone(); - corrupted.poly_commits[0] = ::G1Affine::zero().into(); + corrupted.poly_commits[0] = ::G1Affine::zero().into(); corrupted }; assert_arg_err( @@ -235,8 +234,9 @@ fn sad_path_recover_payload_corrupt_shares() { // corrupted index, out of bounds { let mut shares_bad_indices = shares.clone(); + let multi_open_domain_size = advz.multi_open_domain(common.multiplicity).unwrap().size(); for i in 0..shares_bad_indices.len() { - shares_bad_indices[i].index += u32::try_from(advz.multi_open_domain.size()).unwrap(); + shares_bad_indices[i].index += u32::try_from(multi_open_domain_size).unwrap(); advz.recover_payload(&shares_bad_indices, &common) .expect_err("recover_payload should fail when indices are out of bounds"); } @@ -248,10 +248,10 @@ fn verify_share_with_multiplicity() { let advz_params = AdvzParams { recovery_threshold: 16, num_storage_nodes: 20, - multiplicity: 4, + max_multiplicity: 4, payload_len: 4000, }; - let (mut advz, payload) = advz_init_with(advz_params); + let (mut advz, payload) = advz_init_with::(advz_params); let disperse = advz.disperse(payload).unwrap(); let (shares, common, commit) = (disperse.shares, disperse.common, disperse.commit); @@ -267,10 +267,10 @@ fn sad_path_verify_share_with_multiplicity() { let advz_params = AdvzParams { recovery_threshold: 16, num_storage_nodes: 20, - multiplicity: 32, // payload fitting into a single polynomial + max_multiplicity: 32, // payload fitting into a single polynomial payload_len: 8200, }; - let (mut advz, payload) = advz_init_with(advz_params); + let (mut advz, payload) = advz_init_with::(advz_params); let disperse = advz.disperse(payload).unwrap(); let (shares, common, commit) = (disperse.shares, disperse.common, disperse.commit); @@ -359,10 +359,101 @@ fn verify_share_with_different_multiplicity_helper( } } +#[test] +fn max_multiplicity() { + // regression test for https://github.com/EspressoSystems/jellyfish/issues/663 + + // play with these items + let num_storage_nodes = 6; + let recovery_threshold = 4; + let max_multiplicity = 1 << 5; // intentionally large so as to fit many payload sizes into a single polynomial + + let payload_byte_lens = [0, 1, 100, 10_000]; + type E = Bn254; + + // more items as a function of the above + let (mut advz, payload_bytes) = 
advz_init_with::(AdvzParams { + recovery_threshold, + num_storage_nodes, + max_multiplicity, + payload_len: *payload_byte_lens.iter().max().unwrap(), + }); + let elem_byte_len = bytes_to_field::elem_byte_capacity::<::ScalarField>(); + let (mut found_small_payload, mut found_large_payload) = (false, false); + + for payload_byte_len in payload_byte_lens { + let payload = &payload_bytes[..payload_byte_len]; + let num_payload_elems = payload_byte_len.div_ceil(elem_byte_len) as u32; + + let disperse = advz.disperse(payload).unwrap(); + let (shares, common, commit) = (disperse.shares, disperse.common, disperse.commit); + + // test: multiplicity set correctly + assert!( + common.multiplicity <= max_multiplicity, + "derived multiplicity should never exceed max_multiplicity" + ); + if num_payload_elems < max_multiplicity * recovery_threshold { + // small payload + found_small_payload = true; + assert!( + num_payload_elems <= common.multiplicity * advz.recovery_threshold, + "derived multiplicity too small" + ); + + if num_payload_elems > 0 { + // TODO TEMPORARY: enforce power-of-2 + // https://github.com/EspressoSystems/jellyfish/issues/668 + // + // After this issue is fixed the following test should use + // `common.multiplicity - 1` instead of `common.multiplicity / 2`. + assert!( + num_payload_elems > common.multiplicity / 2 * advz.recovery_threshold, + "derived multiplicity too large: payload_byte_len {}, common.multiplicity {}", + payload_byte_len, + common.multiplicity + ); + } else { + assert_eq!( + common.multiplicity, 1, + "zero-length payload should have multiplicity 1, found {}", + common.multiplicity + ); + } + + assert!( + common.poly_commits.len() <= 1, + "small payload should fit into a single polynomial" + ); + } else { + // large payload + found_large_payload = true; + assert_eq!( + common.multiplicity, max_multiplicity, + "derived multiplicity should equal max_multiplicity for large payload" + ); + } + + // sanity: recover payload + let bytes_recovered = advz.recover_payload(&shares, &common).unwrap(); + assert_eq!(bytes_recovered, payload); + + // sanity: verify shares + for share in shares { + advz.verify_share(&share, &common, &commit) + .unwrap() + .unwrap(); + } + } + + assert!(found_large_payload, "missing test for large payload"); + assert!(found_small_payload, "missing test for small payload"); +} + struct AdvzParams { recovery_threshold: u32, num_storage_nodes: u32, - multiplicity: u32, + max_multiplicity: u32, payload_len: usize, } @@ -371,29 +462,29 @@ struct AdvzParams { /// Returns the following tuple: /// 1. An initialized [`Advz`] instance. /// 2. A `Vec` filled with random bytes. 
-pub(super) fn advz_init() -> (Advz, Vec) { +pub(super) fn advz_init() -> (Advz, Vec) { let advz_params = AdvzParams { recovery_threshold: 16, num_storage_nodes: 20, - multiplicity: 1, + max_multiplicity: 1, payload_len: 4000, }; advz_init_with(advz_params) } -fn advz_init_with(advz_params: AdvzParams) -> (Advz, Vec) { +fn advz_init_with(advz_params: AdvzParams) -> (Advz, Vec) { let mut rng = jf_utils::test_rng(); - let poly_len = advz_params.recovery_threshold * advz_params.multiplicity; + let poly_len = advz_params.recovery_threshold * advz_params.max_multiplicity; let srs = init_srs(poly_len as usize, &mut rng); assert_ne!( - advz_params.multiplicity, 0, + advz_params.max_multiplicity, 0, "multiplicity should not be zero" ); - let advz = if advz_params.multiplicity > 1 { + let advz = if advz_params.max_multiplicity > 1 { Advz::with_multiplicity( advz_params.num_storage_nodes, advz_params.recovery_threshold, - advz_params.multiplicity, + advz_params.max_multiplicity, srs, ) .unwrap() diff --git a/vid/tests/vid/mod.rs b/vid/tests/vid/mod.rs index cf8c3a48a..fb3fe7c55 100644 --- a/vid/tests/vid/mod.rs +++ b/vid/tests/vid/mod.rs @@ -8,28 +8,32 @@ use jf_vid::{VidError, VidResult, VidScheme}; /// Correctness test generic over anything that impls [`VidScheme`] /// -/// `pub` visibility, but it's not part of this crate's public API -/// because it's in an integration test. +/// TODO this test should not have a `max_multiplicities` arg. It is intended to +/// be generic over the [`VidScheme`] and a generic VID scheme does not have a +/// multiplicity arg. +/// +/// `pub` visibility, but it's not part of this crate's public +/// API because it's in an integration test. /// pub fn round_trip( vid_factory: impl Fn(u32, u32, u32) -> V, vid_sizes: &[(u32, u32)], - multiplicities: &[u32], + max_multiplicities: &[u32], payload_byte_lens: &[u32], rng: &mut R, ) where V: VidScheme, R: RngCore + CryptoRng, { - for (&mult, &(recovery_threshold, num_storage_nodes)) in - zip(multiplicities.iter().cycle(), vid_sizes) + for (&max_multiplicity, &(recovery_threshold, num_storage_nodes)) in + zip(max_multiplicities.iter().cycle(), vid_sizes) { - let mut vid = vid_factory(recovery_threshold, num_storage_nodes, mult); + let mut vid = vid_factory(recovery_threshold, num_storage_nodes, max_multiplicity); for &len in payload_byte_lens { println!( - "m: {} n: {} mult: {} byte_len: {}", - recovery_threshold, num_storage_nodes, mult, len + "m: {} n: {} byte_len: {} max_mult: {}", + recovery_threshold, num_storage_nodes, len, max_multiplicity ); let bytes_random = { @@ -43,7 +47,7 @@ pub fn round_trip( assert_eq!(shares.len(), num_storage_nodes as usize); assert_eq!(commit, vid.commit_only(&bytes_random).unwrap()); assert_eq!(len, V::get_payload_byte_len(&common)); - assert_eq!(mult, V::get_multiplicity(&common)); + assert!(V::get_multiplicity(&common) <= max_multiplicity); assert_eq!(num_storage_nodes, V::get_num_storage_nodes(&common)); for share in shares.iter() {
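Taken together, the new API is exercised end to end as follows. This is a hypothetical usage sketch mirroring the tests: `init_srs` and `test_rng` are the crate's test helpers, and the parameter values are illustrative:

```rust
use ark_bn254::Bn254;
use sha2::Sha256;

fn demo() {
    let mut rng = jf_utils::test_rng();
    let (recovery_threshold, num_storage_nodes, max_multiplicity) = (4u32, 6u32, 32u32);
    // SRS must support degree up to recovery_threshold * max_multiplicity - 1.
    let srs = init_srs((recovery_threshold * max_multiplicity) as usize, &mut rng);
    let mut advz = Advz::<Bn254, Sha256>::with_multiplicity(
        num_storage_nodes,
        recovery_threshold,
        max_multiplicity,
        srs,
    )
    .unwrap();

    // Small payload: the derived multiplicity may be below max_multiplicity.
    let payload = vec![7u8; 100];
    let disperse = advz.disperse(&payload).unwrap();
    assert!(disperse.common.multiplicity <= max_multiplicity);

    // Every share verifies against the commitment...
    for share in disperse.shares.iter() {
        advz.verify_share(share, &disperse.common, &disperse.commit)
            .unwrap() // outer VidResult: args were well formed
            .unwrap(); // inner Result: share is valid
    }

    // ...and any >= recovery_threshold shares recover the payload.
    let recovered = advz
        .recover_payload(&disperse.shares, &disperse.common)
        .unwrap();
    assert_eq!(recovered, payload);
}
```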