From 057953e79e345021c273157c7ea8facbbae8b7e1 Mon Sep 17 00:00:00 2001
From: Jacob Kaufmann
Date: Thu, 7 Mar 2024 14:02:59 -0700
Subject: [PATCH] feat: add (de)compress trait in bls module (#22)

---
 benches/kzg.rs   |  25 ++++----
 src/blob.rs      |   7 ++-
 src/bls.rs       | 153 ++++++++++++++++++++++++++++++++---------
 src/kzg/setup.rs |  42 +++++++------
 src/lib.rs       |   1 +
 5 files changed, 150 insertions(+), 78 deletions(-)

diff --git a/benches/kzg.rs b/benches/kzg.rs
index 512dcca..4e7d28a 100644
--- a/benches/kzg.rs
+++ b/benches/kzg.rs
@@ -1,6 +1,7 @@
 use kateth::{
     blob::Blob,
-    kzg::{Bytes48, Setup},
+    kzg::{Commitment, Proof, Setup},
+    Compress,
 };
 
 use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
@@ -17,15 +18,19 @@ pub fn benchmark(c: &mut Criterion) {
     let blobs: Vec<Vec<u8>> = (0..max_batch_size)
         .map(|_| Blob::<4096>::random(&mut rng).to_bytes())
         .collect();
-    let commitments: Vec<Bytes48> = blobs
-        .iter()
-        .map(|blob| kzg.blob_to_commitment(blob).unwrap().serialize())
-        .collect();
-    let proofs: Vec<Bytes48> = blobs
-        .iter()
-        .zip(commitments.iter())
-        .map(|(blob, commitment)| kzg.blob_proof(blob, commitment).unwrap().serialize())
-        .collect();
+    let mut commitments = Vec::with_capacity(blobs.len());
+    let mut proofs = Vec::with_capacity(blobs.len());
+    for blob in &blobs {
+        let commitment = kzg.blob_to_commitment(blob).unwrap();
+        let mut bytes = [0u8; Commitment::COMPRESSED];
+        commitment.compress(&mut bytes).unwrap();
+        commitments.push(bytes);
+
+        let proof = kzg.blob_proof(blob, &bytes).unwrap();
+        let mut bytes = [0u8; Proof::COMPRESSED];
+        proof.compress(&mut bytes).unwrap();
+        proofs.push(bytes);
+    }
 
     c.bench_function("blob to kzg commitment", |b| {
         b.iter(|| kzg.blob_to_commitment(&blobs[0]))
diff --git a/src/blob.rs b/src/blob.rs
index 065ee06..dc05f80 100644
--- a/src/blob.rs
+++ b/src/blob.rs
@@ -1,5 +1,5 @@
 use crate::{
-    bls::{FiniteFieldError, Fr, P1},
+    bls::{Compress, FiniteFieldError, Fr, P1},
     kzg::{Commitment, Polynomial, Proof, Setup},
 };
 
@@ -79,7 +79,10 @@ impl Blob {
         const DOMAIN: &[u8; 16] = b"FSBLOBVERIFY_V1_";
         let degree = (N as u128).to_be_bytes();
 
-        let comm = commitment.serialize();
+        let mut comm = [0u8; Commitment::COMPRESSED];
+        commitment
+            .compress(comm.as_mut_slice())
+            .expect("sufficient buffer len");
 
         let mut data = Vec::with_capacity(8 + 16 + Commitment::BYTES + Self::BYTES);
         data.extend_from_slice(DOMAIN);
diff --git a/src/bls.rs b/src/bls.rs
index 274d0ce..82b7090 100644
--- a/src/bls.rs
+++ b/src/bls.rs
@@ -10,11 +10,12 @@ use blst::{
     blst_fr, blst_fr_add, blst_fr_cneg, blst_fr_eucl_inverse, blst_fr_from_scalar,
     blst_fr_from_uint64, blst_fr_lshift, blst_fr_mul, blst_fr_rshift, blst_fr_sub,
     blst_lendian_from_scalar, blst_miller_loop, blst_p1, blst_p1_add, blst_p1_affine,
-    blst_p1_affine_in_g1, blst_p1_cneg, blst_p1_compress, blst_p1_deserialize, blst_p1_from_affine,
-    blst_p1_mult, blst_p1_to_affine, blst_p2, blst_p2_add, blst_p2_affine, blst_p2_affine_in_g2,
-    blst_p2_deserialize, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine, blst_scalar,
-    blst_scalar_fr_check, blst_scalar_from_bendian, blst_scalar_from_fr, blst_sha256,
-    blst_uint64_from_fr, p1_affines, BLS12_381_G2, BLS12_381_NEG_G1, BLS12_381_NEG_G2, BLST_ERROR,
+    blst_p1_affine_in_g1, blst_p1_cneg, blst_p1_compress, blst_p1_from_affine, blst_p1_mult,
+    blst_p1_to_affine, blst_p1_uncompress, blst_p2, blst_p2_add, blst_p2_affine,
+    blst_p2_affine_in_g2, blst_p2_compress, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine,
+    blst_p2_uncompress, blst_scalar, blst_scalar_fr_check, blst_scalar_from_bendian,
+    blst_scalar_from_fr, blst_sha256, blst_uint64_from_fr, p1_affines, BLS12_381_G2,
+    BLS12_381_NEG_G1, BLS12_381_NEG_G2, BLST_ERROR,
 };
 
 #[derive(Clone, Copy, Debug)]
@@ -48,6 +49,32 @@ impl From for Error {
     }
 }
 
+/// A data structure that can be serialized into the compressed format defined by Zcash.
+///
+/// github.com/zkcrypto/pairing/blob/34aa52b0f7bef705917252ea63e5a13fa01af551/src/bls12_381/README.md
+pub trait Compress {
+    /// The length in bytes of the compressed representation of `self`.
+    const COMPRESSED: usize;
+
+    /// Compresses `self` into `buf`.
+    ///
+    /// # Errors
+    ///
+    /// Compression will fail if the length of `buf` is less than `Self::COMPRESSED`.
+    fn compress(&self, buf: impl AsMut<[u8]>) -> Result<(), &'static str>;
+}
+
+/// A data structure that can be deserialized from the compressed format defined by Zcash.
+///
+/// github.com/zkcrypto/pairing/blob/34aa52b0f7bef705917252ea63e5a13fa01af551/src/bls12_381/README.md
+pub trait Decompress: Sized {
+    /// The error that can occur upon decompression.
+    type Error;
+
+    /// Decompresses `compressed` into `Self`.
+    fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error>;
+}
+
 #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
 pub struct Fr {
     element: blst_fr,
@@ -355,37 +382,6 @@ impl P1 {
     pub const BITS: usize = 384;
     pub const BYTES: usize = Self::BITS / 8;
 
-    pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
-        let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
-        let mut out = MaybeUninit::<blst_p1>::uninit();
-        unsafe {
-            // NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
-            // then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
-            match blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ptr()) {
-                BLST_ERROR::BLST_SUCCESS => {}
-                BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
-                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
-                other => unreachable!("{other:?}"),
-            }
-            if !blst_p1_affine_in_g1(affine.as_ptr()) {
-                return Err(ECGroupError::NotInGroup);
-            }
-
-            blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr());
-            Ok(Self {
-                element: out.assume_init(),
-            })
-        }
-    }
-
-    pub fn serialize(&self) -> [u8; Self::BYTES] {
-        let mut out = [0; Self::BYTES];
-        unsafe {
-            blst_p1_compress(out.as_mut_ptr(), &self.element);
-        }
-        out
-    }
-
     pub fn lincomb(points: impl AsRef<[Self]>, scalars: impl AsRef<[Fr]>) -> Self {
         let n = cmp::min(points.as_ref().len(), scalars.as_ref().len());
         let mut lincomb = Self::INF;
@@ -506,37 +502,55 @@ impl Neg for P1 {
     }
 }
 
-#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
-pub struct P2 {
-    element: blst_p2,
+impl Compress for P1 {
+    const COMPRESSED: usize = Self::BYTES;
+
+    fn compress(&self, mut buf: impl AsMut<[u8]>) -> Result<(), &'static str> {
+        if buf.as_mut().len() < Self::COMPRESSED {
+            return Err("insufficient buffer length");
+        }
+        unsafe {
+            blst_p1_compress(buf.as_mut().as_mut_ptr(), &self.element);
+        }
+        Ok(())
+    }
 }
 
-impl P2 {
-    pub const BITS: usize = 768;
-    pub const BYTES: usize = Self::BITS / 8;
+impl Decompress for P1 {
+    type Error = ECGroupError;
 
-    pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
-        let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
-        let mut out = MaybeUninit::<blst_p2>::uninit();
+    fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error> {
+        let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
+        let mut out = MaybeUninit::<blst_p1>::uninit();
         unsafe {
-            // NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
+            // NOTE: uncompress performs a curve check but not a subgroup check. if that changes,
             // then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
-            match blst_p2_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()) {
+            match blst_p1_uncompress(affine.as_mut_ptr(), compressed.as_ref().as_ptr()) {
                 BLST_ERROR::BLST_SUCCESS => {}
                 BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
                 BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
                 other => unreachable!("{other:?}"),
            }
-            if !blst_p2_affine_in_g2(affine.as_ptr()) {
+            if !blst_p1_affine_in_g1(affine.as_ptr()) {
                 return Err(ECGroupError::NotInGroup);
             }
 
-            blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr());
+            blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr());
             Ok(Self {
                 element: out.assume_init(),
             })
         }
     }
+}
+
+#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+pub struct P2 {
+    element: blst_p2,
+}
+
+impl P2 {
+    pub const BITS: usize = 768;
+    pub const BYTES: usize = Self::BITS / 8;
 
     // TODO: make available as `const`
     pub fn generator() -> Self {
@@ -608,6 +622,47 @@ impl Mul for P2 {
     }
 }
 
+impl Compress for P2 {
+    const COMPRESSED: usize = Self::BYTES;
+
+    fn compress(&self, mut buf: impl AsMut<[u8]>) -> Result<(), &'static str> {
+        if buf.as_mut().len() < Self::COMPRESSED {
+            return Err("insufficient buffer length");
+        }
+        unsafe {
+            blst_p2_compress(buf.as_mut().as_mut_ptr(), &self.element);
+        }
+        Ok(())
+    }
+}
+
+impl Decompress for P2 {
+    type Error = ECGroupError;
+
+    fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error> {
+        let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
+        let mut out = MaybeUninit::<blst_p2>::uninit();
+        unsafe {
+            // NOTE: uncompress performs a curve check but not a subgroup check. if that changes,
+            // then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
+            match blst_p2_uncompress(affine.as_mut_ptr(), compressed.as_ref().as_ptr()) {
+                BLST_ERROR::BLST_SUCCESS => {}
+                BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
+                BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
+                other => unreachable!("{other:?}"),
+            }
+            if !blst_p2_affine_in_g2(affine.as_ptr()) {
+                return Err(ECGroupError::NotInGroup);
+            }
+
+            blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr());
+            Ok(Self {
+                element: out.assume_init(),
+            })
+        }
+    }
+}
+
 pub fn verify_pairings((a1, a2): (P1, P2), (b1, b2): (P1, P2)) -> bool {
     let mut a1_neg_affine = MaybeUninit::<blst_p1_affine>::uninit();
     let mut a2_affine = MaybeUninit::<blst_p2_affine>::uninit();
diff --git a/src/kzg/setup.rs b/src/kzg/setup.rs
index 999cecb..2e97e99 100644
--- a/src/kzg/setup.rs
+++ b/src/kzg/setup.rs
@@ -7,7 +7,7 @@ use std::{
 
 use super::{Bytes32, Bytes48, Commitment, Error, Polynomial, Proof};
 use crate::{
     blob::{Blob, Error as BlobError},
-    bls::{self, ECGroupError, Error as BlsError, Fr, P1, P2},
+    bls::{self, Decompress, ECGroupError, Error as BlsError, Fr, P1, P2},
     math,
 };
@@ -60,7 +60,7 @@
             // TODO: skip unnecessary allocation
             let point = FixedBytes::<48>::from_slice(point);
             let point =
-                P1::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
+                P1::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
             g1_lagrange[i] = point;
         }
         let g1_lagrange_brp = math::bit_reversal_permutation_boxed_array(g1_lagrange.as_slice());
@@ -75,7 +75,7 @@
             // TODO: skip unnecessary allocation
             let point = FixedBytes::<96>::from_slice(point);
             let point =
-                P2::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
+                P2::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
             g2_monomial[i] = point;
         }
 
@@ -108,8 +108,8 @@
         point: &Bytes32,
         eval: &Bytes32,
     ) -> Result<bool, Error> {
-        let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
-        let commitment = Commitment::deserialize(commitment)
+        let proof = Proof::decompress(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
+        let commitment = Commitment::decompress(commitment)
             .map_err(|err| Error::from(BlsError::ECGroup(err)))?;
         let point =
             Fr::from_be_slice(point).map_err(|err| Error::from(BlsError::FiniteField(err)))?;
@@ -184,7 +184,7 @@
     pub fn blob_proof(&self, blob: impl AsRef<[u8]>, commitment: &Bytes48) -> Result<Proof, Error> {
         let blob: Blob = Blob::from_slice(blob).map_err(Error::from)?;
-        let commitment = Commitment::deserialize(commitment)
+        let commitment = Commitment::decompress(commitment)
             .map_err(|err| Error::from(BlsError::ECGroup(err)))?;
         let proof = self.blob_proof_inner(&blob, &commitment);
         Ok(proof)
     }
@@ -220,9 +220,9 @@
         proof: &Bytes48,
     ) -> Result<bool, Error> {
         let blob: Blob = Blob::from_slice(blob).map_err(Error::from)?;
-        let commitment = Commitment::deserialize(commitment)
+        let commitment = Commitment::decompress(commitment)
             .map_err(|err| Error::from(BlsError::ECGroup(err)))?;
-        let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
+        let proof = Proof::decompress(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
 
         let verified = self.verify_blob_proof_inner(&blob, &commitment, &proof);
         Ok(verified)
@@ -271,12 +271,11 @@
         let commitments: Result<Vec<Commitment>, _> = commitments
             .as_ref()
             .iter()
-            .map(Commitment::deserialize)
+            .map(Commitment::decompress)
             .collect();
         let commitments = commitments.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
-        let proofs: Result<Vec<Proof>, _> =
-            proofs.as_ref().iter().map(Proof::deserialize).collect();
+        let proofs: Result<Vec<Proof>, _> = proofs.as_ref().iter().map(Proof::decompress).collect();
         let proofs = proofs.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
 
         let verified = self.verify_blob_proof_batch_inner(blobs, commitments, proofs);
         Ok(verified)
@@ -288,9 +287,12 @@
 mod tests {
     use super::*;
 
-    use crate::kzg::spec::{
-        BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof, VerifyBlobProofBatch,
-        VerifyProof,
+    use crate::{
+        bls::Compress,
+        kzg::spec::{
+            BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof,
+            VerifyBlobProofBatch, VerifyProof,
+        },
     };
 
     use std::{
@@ -343,7 +345,9 @@
             continue;
         };
         let (expected_proof, expected_y) = expected.unwrap();
-        assert_eq!(proof.serialize(), expected_proof);
+        let mut proof_bytes = [0u8; Proof::COMPRESSED];
+        proof.compress(&mut proof_bytes).unwrap();
+        assert_eq!(proof_bytes, expected_proof);
         assert_eq!(y.to_be_bytes(), expected_y);
     }
 }
@@ -367,7 +371,9 @@
             continue;
         };
         let expected = expected.unwrap();
-        assert_eq!(proof.serialize(), expected);
+        let mut proof_bytes = [0u8; Proof::COMPRESSED];
+        proof.compress(&mut proof_bytes).unwrap();
+        assert_eq!(proof_bytes, expected);
     }
 }
 
@@ -387,7 +393,9 @@
             continue;
         };
         let expected = expected.unwrap();
-        assert_eq!(commitment.serialize(), expected);
+        let mut commitment_bytes = [0u8; Commitment::COMPRESSED];
+        commitment.compress(&mut commitment_bytes).unwrap();
+        assert_eq!(commitment_bytes, expected);
     }
 }
 
diff --git a/src/lib.rs b/src/lib.rs
index 9ed498a..611fc29 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,6 @@
 mod bls;
 mod math;
+pub use bls::{Compress, Decompress};
 
 pub mod blob;
 pub mod kzg;
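Below is a brief usage sketch of the `Compress` and `Decompress` traits introduced by this patch; it is illustrative only and not part of the diff. It round-trips a `Commitment` through its compressed byte representation using only the calls that appear above (`compress`, `decompress`, and the `COMPRESSED` constant); the helper name `compressed_roundtrip` is hypothetical.

use kateth::{kzg::Commitment, Compress, Decompress};

// Hypothetical helper: compress a commitment into a fixed-size stack buffer,
// decompress it back (running the curve and subgroup checks shown in the
// `Decompress` impls above), and confirm the round trip reproduces the bytes.
fn compressed_roundtrip(commitment: &Commitment) -> [u8; Commitment::COMPRESSED] {
    let mut bytes = [0u8; Commitment::COMPRESSED];
    // Fails only if the buffer is shorter than `Commitment::COMPRESSED`.
    commitment.compress(bytes.as_mut_slice()).unwrap();

    // Rejects encodings that are malformed, off the curve, or outside the subgroup.
    let point = Commitment::decompress(bytes).unwrap();

    let mut roundtrip = [0u8; Commitment::COMPRESSED];
    point.compress(roundtrip.as_mut_slice()).unwrap();
    assert_eq!(bytes, roundtrip);
    bytes
}

Because `compress` writes into a caller-provided `impl AsMut<[u8]>` rather than returning a fresh array, callers can reuse stack buffers, as the updated benchmark and blob code do.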