Skip to content

Commit

Permalink
feat: add (de)compress trait in bls module (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobkaufmann authored Mar 7, 2024
1 parent 2eb5170 commit 057953e
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 78 deletions.
25 changes: 15 additions & 10 deletions benches/kzg.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use kateth::{
blob::Blob,
kzg::{Bytes48, Setup},
kzg::{Commitment, Proof, Setup},
Compress,
};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
Expand All @@ -17,15 +18,19 @@ pub fn benchmark(c: &mut Criterion) {
let blobs: Vec<Vec<u8>> = (0..max_batch_size)
.map(|_| Blob::<4096>::random(&mut rng).to_bytes())
.collect();
let commitments: Vec<Bytes48> = blobs
.iter()
.map(|blob| kzg.blob_to_commitment(blob).unwrap().serialize())
.collect();
let proofs: Vec<Bytes48> = blobs
.iter()
.zip(commitments.iter())
.map(|(blob, commitment)| kzg.blob_proof(blob, commitment).unwrap().serialize())
.collect();
let mut commitments = Vec::with_capacity(blobs.len());
let mut proofs = Vec::with_capacity(blobs.len());
for blob in &blobs {
let commitment = kzg.blob_to_commitment(blob).unwrap();
let mut bytes = [0u8; Commitment::COMPRESSED];
commitment.compress(&mut bytes).unwrap();
commitments.push(bytes);

let proof = kzg.blob_proof(blob, &bytes).unwrap();
let mut bytes = [0u8; Proof::COMPRESSED];
proof.compress(&mut bytes).unwrap();
proofs.push(bytes);
}

c.bench_function("blob to kzg commitment", |b| {
b.iter(|| kzg.blob_to_commitment(&blobs[0]))
Expand Down
7 changes: 5 additions & 2 deletions src/blob.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
bls::{FiniteFieldError, Fr, P1},
bls::{Compress, FiniteFieldError, Fr, P1},
kzg::{Commitment, Polynomial, Proof, Setup},
};

Expand Down Expand Up @@ -79,7 +79,10 @@ impl<const N: usize> Blob<N> {
const DOMAIN: &[u8; 16] = b"FSBLOBVERIFY_V1_";
let degree = (N as u128).to_be_bytes();

let comm = commitment.serialize();
let mut comm = [0u8; Commitment::COMPRESSED];
commitment
.compress(comm.as_mut_slice())
.expect("sufficient buffer len");

let mut data = Vec::with_capacity(8 + 16 + Commitment::BYTES + Self::BYTES);
data.extend_from_slice(DOMAIN);
Expand Down
153 changes: 104 additions & 49 deletions src/bls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ use blst::{
blst_fr, blst_fr_add, blst_fr_cneg, blst_fr_eucl_inverse, blst_fr_from_scalar,
blst_fr_from_uint64, blst_fr_lshift, blst_fr_mul, blst_fr_rshift, blst_fr_sub,
blst_lendian_from_scalar, blst_miller_loop, blst_p1, blst_p1_add, blst_p1_affine,
blst_p1_affine_in_g1, blst_p1_cneg, blst_p1_compress, blst_p1_deserialize, blst_p1_from_affine,
blst_p1_mult, blst_p1_to_affine, blst_p2, blst_p2_add, blst_p2_affine, blst_p2_affine_in_g2,
blst_p2_deserialize, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine, blst_scalar,
blst_scalar_fr_check, blst_scalar_from_bendian, blst_scalar_from_fr, blst_sha256,
blst_uint64_from_fr, p1_affines, BLS12_381_G2, BLS12_381_NEG_G1, BLS12_381_NEG_G2, BLST_ERROR,
blst_p1_affine_in_g1, blst_p1_cneg, blst_p1_compress, blst_p1_from_affine, blst_p1_mult,
blst_p1_to_affine, blst_p1_uncompress, blst_p2, blst_p2_add, blst_p2_affine,
blst_p2_affine_in_g2, blst_p2_compress, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine,
blst_p2_uncompress, blst_scalar, blst_scalar_fr_check, blst_scalar_from_bendian,
blst_scalar_from_fr, blst_sha256, blst_uint64_from_fr, p1_affines, BLS12_381_G2,
BLS12_381_NEG_G1, BLS12_381_NEG_G2, BLST_ERROR,
};

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -48,6 +49,32 @@ impl From<ECGroupError> for Error {
}
}

/// A data structure that can be serialized into the compressed format defined by Zcash.
///
/// github.com/zkcrypto/pairing/blob/34aa52b0f7bef705917252ea63e5a13fa01af551/src/bls12_381/README.md
pub trait Compress {
/// The length in bytes of the compressed representation of `self`.
const COMPRESSED: usize;

/// Compresses `self` into `buf`.
///
/// # Errors
///
/// Compression will fail if the length of `buf` is less than `Self::COMPRESSED`.
fn compress(&self, buf: impl AsMut<[u8]>) -> Result<(), &'static str>;
}

/// A data structure that can be deserialized from the compressed format defined by Zcash.
///
/// github.com/zkcrypto/pairing/blob/34aa52b0f7bef705917252ea63e5a13fa01af551/src/bls12_381/README.md
pub trait Decompress: Sized {
/// The error that can occur upon decompression.
type Error;

/// Decompresses `compressed` into `Self`.
fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error>;
}

#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Fr {
element: blst_fr,
Expand Down Expand Up @@ -355,37 +382,6 @@ impl P1 {
pub const BITS: usize = 384;
pub const BYTES: usize = Self::BITS / 8;

pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut out = MaybeUninit::<blst_p1>::uninit();
unsafe {
// NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
other => unreachable!("{other:?}"),
}
if !blst_p1_affine_in_g1(affine.as_ptr()) {
return Err(ECGroupError::NotInGroup);
}

blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr());
Ok(Self {
element: out.assume_init(),
})
}
}

pub fn serialize(&self) -> [u8; Self::BYTES] {
let mut out = [0; Self::BYTES];
unsafe {
blst_p1_compress(out.as_mut_ptr(), &self.element);
}
out
}

pub fn lincomb(points: impl AsRef<[Self]>, scalars: impl AsRef<[Fr]>) -> Self {
let n = cmp::min(points.as_ref().len(), scalars.as_ref().len());
let mut lincomb = Self::INF;
Expand Down Expand Up @@ -506,37 +502,55 @@ impl Neg for P1 {
}
}

#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct P2 {
element: blst_p2,
impl Compress for P1 {
const COMPRESSED: usize = Self::BYTES;

fn compress(&self, mut buf: impl AsMut<[u8]>) -> Result<(), &'static str> {
if buf.as_mut().len() < Self::COMPRESSED {
return Err("insufficient buffer length");
}
unsafe {
blst_p1_compress(buf.as_mut().as_mut_ptr(), &self.element);
}
Ok(())
}
}

impl P2 {
pub const BITS: usize = 768;
pub const BYTES: usize = Self::BITS / 8;
impl Decompress for P1 {
type Error = ECGroupError;

pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
let mut out = MaybeUninit::<blst_p2>::uninit();
fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error> {
let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut out = MaybeUninit::<blst_p1>::uninit();
unsafe {
// NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
// NOTE: uncompress performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p2_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()) {
match blst_p1_uncompress(affine.as_mut_ptr(), compressed.as_ref().as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
other => unreachable!("{other:?}"),
}
if !blst_p2_affine_in_g2(affine.as_ptr()) {
if !blst_p1_affine_in_g1(affine.as_ptr()) {
return Err(ECGroupError::NotInGroup);
}

blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr());
blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr());
Ok(Self {
element: out.assume_init(),
})
}
}
}

#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct P2 {
element: blst_p2,
}

impl P2 {
pub const BITS: usize = 768;
pub const BYTES: usize = Self::BITS / 8;

// TODO: make available as `const`
pub fn generator() -> Self {
Expand Down Expand Up @@ -608,6 +622,47 @@ impl Mul<Fr> for P2 {
}
}

impl Compress for P2 {
const COMPRESSED: usize = Self::BYTES;

fn compress(&self, mut buf: impl AsMut<[u8]>) -> Result<(), &'static str> {
if buf.as_mut().len() < Self::COMPRESSED {
return Err("insufficient buffer length");
}
unsafe {
blst_p2_compress(buf.as_mut().as_mut_ptr(), &self.element);
}
Ok(())
}
}

impl Decompress for P2 {
type Error = ECGroupError;

fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error> {
let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
let mut out = MaybeUninit::<blst_p2>::uninit();
unsafe {
// NOTE: uncompress performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p2_uncompress(affine.as_mut_ptr(), compressed.as_ref().as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
other => unreachable!("{other:?}"),
}
if !blst_p2_affine_in_g2(affine.as_ptr()) {
return Err(ECGroupError::NotInGroup);
}

blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr());
Ok(Self {
element: out.assume_init(),
})
}
}
}

pub fn verify_pairings((a1, a2): (P1, P2), (b1, b2): (P1, P2)) -> bool {
let mut a1_neg_affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut a2_affine = MaybeUninit::<blst_p2_affine>::uninit();
Expand Down
42 changes: 25 additions & 17 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
use super::{Bytes32, Bytes48, Commitment, Error, Polynomial, Proof};
use crate::{
blob::{Blob, Error as BlobError},
bls::{self, ECGroupError, Error as BlsError, Fr, P1, P2},
bls::{self, Decompress, ECGroupError, Error as BlsError, Fr, P1, P2},
math,
};

Expand Down Expand Up @@ -60,7 +60,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
// TODO: skip unnecessary allocation
let point = FixedBytes::<48>::from_slice(point);
let point =
P1::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
P1::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g1_lagrange[i] = point;
}
let g1_lagrange_brp = math::bit_reversal_permutation_boxed_array(g1_lagrange.as_slice());
Expand All @@ -75,7 +75,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
// TODO: skip unnecessary allocation
let point = FixedBytes::<96>::from_slice(point);
let point =
P2::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
P2::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g2_monomial[i] = point;
}

Expand Down Expand Up @@ -108,8 +108,8 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
point: &Bytes32,
eval: &Bytes32,
) -> Result<bool, Error> {
let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let commitment = Commitment::deserialize(commitment)
let proof = Proof::decompress(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let commitment = Commitment::decompress(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let point =
Fr::from_be_slice(point).map_err(|err| Error::from(BlsError::FiniteField(err)))?;
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {

pub fn blob_proof(&self, blob: impl AsRef<[u8]>, commitment: &Bytes48) -> Result<Proof, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let commitment = Commitment::deserialize(commitment)
let commitment = Commitment::decompress(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = self.blob_proof_inner(&blob, &commitment);
Ok(proof)
Expand Down Expand Up @@ -220,9 +220,9 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
proof: &Bytes48,
) -> Result<bool, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let commitment = Commitment::deserialize(commitment)
let commitment = Commitment::decompress(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = Proof::decompress(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let verified = self.verify_blob_proof_inner(&blob, &commitment, &proof);
Ok(verified)
Expand Down Expand Up @@ -271,12 +271,11 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
let commitments: Result<Vec<Commitment>, _> = commitments
.as_ref()
.iter()
.map(Commitment::deserialize)
.map(Commitment::decompress)
.collect();
let commitments = commitments.map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let proofs: Result<Vec<Proof>, _> =
proofs.as_ref().iter().map(Proof::deserialize).collect();
let proofs: Result<Vec<Proof>, _> = proofs.as_ref().iter().map(Proof::decompress).collect();
let proofs = proofs.map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let verified = self.verify_blob_proof_batch_inner(blobs, commitments, proofs);
Expand All @@ -288,9 +287,12 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
mod tests {
use super::*;

use crate::kzg::spec::{
BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof, VerifyBlobProofBatch,
VerifyProof,
use crate::{
bls::Compress,
kzg::spec::{
BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof,
VerifyBlobProofBatch, VerifyProof,
},
};

use std::{
Expand Down Expand Up @@ -343,7 +345,9 @@ mod tests {
continue;
};
let (expected_proof, expected_y) = expected.unwrap();
assert_eq!(proof.serialize(), expected_proof);
let mut proof_bytes = [0u8; Proof::COMPRESSED];
proof.compress(&mut proof_bytes).unwrap();
assert_eq!(proof_bytes, expected_proof);
assert_eq!(y.to_be_bytes(), expected_y);
}
}
Expand All @@ -367,7 +371,9 @@ mod tests {
continue;
};
let expected = expected.unwrap();
assert_eq!(proof.serialize(), expected);
let mut proof_bytes = [0u8; Proof::COMPRESSED];
proof.compress(&mut proof_bytes).unwrap();
assert_eq!(proof_bytes, expected);
}
}

Expand All @@ -387,7 +393,9 @@ mod tests {
continue;
};
let expected = expected.unwrap();
assert_eq!(commitment.serialize(), expected);
let mut commitment_bytes = [0u8; Commitment::COMPRESSED];
commitment.compress(&mut commitment_bytes).unwrap();
assert_eq!(commitment_bytes, expected);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod bls;
mod math;

pub use bls::{Compress, Decompress};
pub mod blob;
pub mod kzg;

0 comments on commit 057953e

Please sign in to comment.