Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add (de)compress trait in bls module #22

Merged
merged 9 commits into from
Mar 7, 2024
25 changes: 15 additions & 10 deletions benches/kzg.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use kateth::{
blob::Blob,
kzg::{Bytes48, Setup},
kzg::{Commitment, Proof, Setup},
Compress,
};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
Expand All @@ -17,15 +18,19 @@ pub fn benchmark(c: &mut Criterion) {
let blobs: Vec<Vec<u8>> = (0..max_batch_size)
.map(|_| Blob::<4096>::random(&mut rng).to_bytes())
.collect();
let commitments: Vec<Bytes48> = blobs
.iter()
.map(|blob| kzg.blob_to_commitment(blob).unwrap().serialize())
.collect();
let proofs: Vec<Bytes48> = blobs
.iter()
.zip(commitments.iter())
.map(|(blob, commitment)| kzg.blob_proof(blob, commitment).unwrap().serialize())
.collect();
let mut commitments = Vec::with_capacity(blobs.len());
let mut proofs = Vec::with_capacity(blobs.len());
for blob in &blobs {
let commitment = kzg.blob_to_commitment(blob).unwrap();
let mut bytes = [0u8; Commitment::COMPRESSED];
commitment.compress(&mut bytes).unwrap();
commitments.push(bytes);

let proof = kzg.blob_proof(blob, &bytes).unwrap();
let mut bytes = [0u8; Proof::COMPRESSED];
proof.compress(&mut bytes).unwrap();
proofs.push(bytes);
}

c.bench_function("blob to kzg commitment", |b| {
b.iter(|| kzg.blob_to_commitment(&blobs[0]))
Expand Down
7 changes: 5 additions & 2 deletions src/blob.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
bls::{FiniteFieldError, Fr, P1},
bls::{Compress, FiniteFieldError, Fr, P1},
kzg::{Commitment, Polynomial, Proof, Setup},
};

Expand Down Expand Up @@ -79,7 +79,10 @@ impl<const N: usize> Blob<N> {
const DOMAIN: &[u8; 16] = b"FSBLOBVERIFY_V1_";
let degree = (N as u128).to_be_bytes();

let comm = commitment.serialize();
let mut comm = [0u8; Commitment::COMPRESSED];
commitment
.compress(comm.as_mut_slice())
.expect("sufficient buffer len");

let mut data = Vec::with_capacity(8 + 16 + Commitment::BYTES + Self::BYTES);
data.extend_from_slice(DOMAIN);
Expand Down
153 changes: 104 additions & 49 deletions src/bls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ use blst::{
blst_fr, blst_fr_add, blst_fr_cneg, blst_fr_eucl_inverse, blst_fr_from_scalar,
blst_fr_from_uint64, blst_fr_lshift, blst_fr_mul, blst_fr_rshift, blst_fr_sub,
blst_lendian_from_scalar, blst_miller_loop, blst_p1, blst_p1_add, blst_p1_affine,
blst_p1_affine_in_g1, blst_p1_cneg, blst_p1_compress, blst_p1_deserialize, blst_p1_from_affine,
blst_p1_mult, blst_p1_to_affine, blst_p2, blst_p2_add, blst_p2_affine, blst_p2_affine_in_g2,
blst_p2_deserialize, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine, blst_scalar,
blst_scalar_fr_check, blst_scalar_from_bendian, blst_scalar_from_fr, blst_sha256,
blst_uint64_from_fr, p1_affines, BLS12_381_G2, BLS12_381_NEG_G1, BLS12_381_NEG_G2, BLST_ERROR,
blst_p1_affine_in_g1, blst_p1_cneg, blst_p1_compress, blst_p1_from_affine, blst_p1_mult,
blst_p1_to_affine, blst_p1_uncompress, blst_p2, blst_p2_add, blst_p2_affine,
blst_p2_affine_in_g2, blst_p2_compress, blst_p2_from_affine, blst_p2_mult, blst_p2_to_affine,
blst_p2_uncompress, blst_scalar, blst_scalar_fr_check, blst_scalar_from_bendian,
blst_scalar_from_fr, blst_sha256, blst_uint64_from_fr, p1_affines, BLS12_381_G2,
BLS12_381_NEG_G1, BLS12_381_NEG_G2, BLST_ERROR,
};

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -48,6 +49,32 @@ impl From<ECGroupError> for Error {
}
}

/// A data structure that can be serialized into the compressed format defined by Zcash.
///
/// github.com/zkcrypto/pairing/blob/34aa52b0f7bef705917252ea63e5a13fa01af551/src/bls12_381/README.md
pub trait Compress {
/// The length in bytes of the compressed representation of `self`.
const COMPRESSED: usize;

/// Compresses `self` into `buf`.
///
/// # Errors
///
/// Compression will fail if the length of `buf` is less than `Self::COMPRESSED`.
fn compress(&self, buf: impl AsMut<[u8]>) -> Result<(), &'static str>;
}

/// A data structure that can be deserialized from the compressed format defined by Zcash.
///
/// github.com/zkcrypto/pairing/blob/34aa52b0f7bef705917252ea63e5a13fa01af551/src/bls12_381/README.md
pub trait Decompress: Sized {
/// The error that can occur upon decompression.
type Error;

/// Decompresses `compressed` into `Self`.
fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error>;
}

#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Fr {
element: blst_fr,
Expand Down Expand Up @@ -355,37 +382,6 @@ impl P1 {
pub const BITS: usize = 384;
pub const BYTES: usize = Self::BITS / 8;

pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut out = MaybeUninit::<blst_p1>::uninit();
unsafe {
// NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
other => unreachable!("{other:?}"),
}
if !blst_p1_affine_in_g1(affine.as_ptr()) {
return Err(ECGroupError::NotInGroup);
}

blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr());
Ok(Self {
element: out.assume_init(),
})
}
}

pub fn serialize(&self) -> [u8; Self::BYTES] {
let mut out = [0; Self::BYTES];
unsafe {
blst_p1_compress(out.as_mut_ptr(), &self.element);
}
out
}

pub fn lincomb(points: impl AsRef<[Self]>, scalars: impl AsRef<[Fr]>) -> Self {
let n = cmp::min(points.as_ref().len(), scalars.as_ref().len());
let mut lincomb = Self::INF;
Expand Down Expand Up @@ -506,37 +502,55 @@ impl Neg for P1 {
}
}

#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct P2 {
element: blst_p2,
impl Compress for P1 {
const COMPRESSED: usize = Self::BYTES;

fn compress(&self, mut buf: impl AsMut<[u8]>) -> Result<(), &'static str> {
if buf.as_mut().len() < Self::COMPRESSED {
return Err("insufficient buffer length");
}
unsafe {
blst_p1_compress(buf.as_mut().as_mut_ptr(), &self.element);
}
Ok(())
}
}

impl P2 {
pub const BITS: usize = 768;
pub const BYTES: usize = Self::BITS / 8;
impl Decompress for P1 {
type Error = ECGroupError;

pub fn deserialize(bytes: &[u8; Self::BYTES]) -> Result<Self, ECGroupError> {
let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
let mut out = MaybeUninit::<blst_p2>::uninit();
fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error> {
let mut affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut out = MaybeUninit::<blst_p1>::uninit();
unsafe {
// NOTE: deserialize performs a curve check but not a subgroup check. if that changes,
// NOTE: uncompress performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p2_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()) {
match blst_p1_uncompress(affine.as_mut_ptr(), compressed.as_ref().as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
other => unreachable!("{other:?}"),
}
if !blst_p2_affine_in_g2(affine.as_ptr()) {
if !blst_p1_affine_in_g1(affine.as_ptr()) {
return Err(ECGroupError::NotInGroup);
}

blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr());
blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr());
Ok(Self {
element: out.assume_init(),
})
}
}
}

#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct P2 {
element: blst_p2,
}

impl P2 {
pub const BITS: usize = 768;
pub const BYTES: usize = Self::BITS / 8;

// TODO: make available as `const`
pub fn generator() -> Self {
Expand Down Expand Up @@ -608,6 +622,47 @@ impl Mul<Fr> for P2 {
}
}

impl Compress for P2 {
const COMPRESSED: usize = Self::BYTES;

fn compress(&self, mut buf: impl AsMut<[u8]>) -> Result<(), &'static str> {
if buf.as_mut().len() < Self::COMPRESSED {
return Err("insufficient buffer length");
}
unsafe {
blst_p2_compress(buf.as_mut().as_mut_ptr(), &self.element);
}
Ok(())
}
}

impl Decompress for P2 {
type Error = ECGroupError;

fn decompress(compressed: impl AsRef<[u8]>) -> Result<Self, Self::Error> {
let mut affine = MaybeUninit::<blst_p2_affine>::uninit();
let mut out = MaybeUninit::<blst_p2>::uninit();
unsafe {
// NOTE: uncompress performs a curve check but not a subgroup check. if that changes,
// then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests.
match blst_p2_uncompress(affine.as_mut_ptr(), compressed.as_ref().as_ptr()) {
BLST_ERROR::BLST_SUCCESS => {}
BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding),
BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve),
other => unreachable!("{other:?}"),
}
if !blst_p2_affine_in_g2(affine.as_ptr()) {
return Err(ECGroupError::NotInGroup);
}

blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr());
Ok(Self {
element: out.assume_init(),
})
}
}
}

pub fn verify_pairings((a1, a2): (P1, P2), (b1, b2): (P1, P2)) -> bool {
let mut a1_neg_affine = MaybeUninit::<blst_p1_affine>::uninit();
let mut a2_affine = MaybeUninit::<blst_p2_affine>::uninit();
Expand Down
42 changes: 25 additions & 17 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
use super::{Bytes32, Bytes48, Commitment, Error, Polynomial, Proof};
use crate::{
blob::{Blob, Error as BlobError},
bls::{self, ECGroupError, Error as BlsError, Fr, P1, P2},
bls::{self, Decompress, ECGroupError, Error as BlsError, Fr, P1, P2},
math,
};

Expand Down Expand Up @@ -60,7 +60,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
// TODO: skip unnecessary allocation
let point = FixedBytes::<48>::from_slice(point);
let point =
P1::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
P1::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g1_lagrange[i] = point;
}
let g1_lagrange_brp = math::bit_reversal_permutation_boxed_array(g1_lagrange.as_slice());
Expand All @@ -75,7 +75,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
// TODO: skip unnecessary allocation
let point = FixedBytes::<96>::from_slice(point);
let point =
P2::deserialize(&point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
P2::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g2_monomial[i] = point;
}

Expand Down Expand Up @@ -108,8 +108,8 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
point: &Bytes32,
eval: &Bytes32,
) -> Result<bool, Error> {
let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let commitment = Commitment::deserialize(commitment)
let proof = Proof::decompress(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let commitment = Commitment::decompress(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let point =
Fr::from_be_slice(point).map_err(|err| Error::from(BlsError::FiniteField(err)))?;
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {

pub fn blob_proof(&self, blob: impl AsRef<[u8]>, commitment: &Bytes48) -> Result<Proof, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let commitment = Commitment::deserialize(commitment)
let commitment = Commitment::decompress(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = self.blob_proof_inner(&blob, &commitment);
Ok(proof)
Expand Down Expand Up @@ -220,9 +220,9 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
proof: &Bytes48,
) -> Result<bool, Error> {
let blob: Blob<G1> = Blob::from_slice(blob).map_err(Error::from)?;
let commitment = Commitment::deserialize(commitment)
let commitment = Commitment::decompress(commitment)
.map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = Proof::deserialize(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;
let proof = Proof::decompress(proof).map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let verified = self.verify_blob_proof_inner(&blob, &commitment, &proof);
Ok(verified)
Expand Down Expand Up @@ -271,12 +271,11 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
let commitments: Result<Vec<Commitment>, _> = commitments
.as_ref()
.iter()
.map(Commitment::deserialize)
.map(Commitment::decompress)
.collect();
let commitments = commitments.map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let proofs: Result<Vec<Proof>, _> =
proofs.as_ref().iter().map(Proof::deserialize).collect();
let proofs: Result<Vec<Proof>, _> = proofs.as_ref().iter().map(Proof::decompress).collect();
let proofs = proofs.map_err(|err| Error::from(BlsError::ECGroup(err)))?;

let verified = self.verify_blob_proof_batch_inner(blobs, commitments, proofs);
Expand All @@ -288,9 +287,12 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
mod tests {
use super::*;

use crate::kzg::spec::{
BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof, VerifyBlobProofBatch,
VerifyProof,
use crate::{
bls::Compress,
kzg::spec::{
BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof,
VerifyBlobProofBatch, VerifyProof,
},
};

use std::{
Expand Down Expand Up @@ -343,7 +345,9 @@ mod tests {
continue;
};
let (expected_proof, expected_y) = expected.unwrap();
assert_eq!(proof.serialize(), expected_proof);
let mut proof_bytes = [0u8; Proof::COMPRESSED];
proof.compress(&mut proof_bytes).unwrap();
assert_eq!(proof_bytes, expected_proof);
assert_eq!(y.to_be_bytes(), expected_y);
}
}
Expand All @@ -367,7 +371,9 @@ mod tests {
continue;
};
let expected = expected.unwrap();
assert_eq!(proof.serialize(), expected);
let mut proof_bytes = [0u8; Proof::COMPRESSED];
proof.compress(&mut proof_bytes).unwrap();
assert_eq!(proof_bytes, expected);
}
}

Expand All @@ -387,7 +393,9 @@ mod tests {
continue;
};
let expected = expected.unwrap();
assert_eq!(commitment.serialize(), expected);
let mut commitment_bytes = [0u8; Commitment::COMPRESSED];
commitment.compress(&mut commitment_bytes).unwrap();
assert_eq!(commitment_bytes, expected);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod bls;
mod math;

pub use bls::{Compress, Decompress};
pub mod blob;
pub mod kzg;
Loading