From 67509daca7064fe3e10e44887a7227a70ef744a5 Mon Sep 17 00:00:00 2001 From: Jacob Kaufmann Date: Wed, 13 Mar 2024 12:20:30 -0600 Subject: [PATCH] refactor: add byte array type for deserialization (#46) --- Cargo.toml | 2 +- src/bytes.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++++ src/kzg/setup.rs | 33 +++++++++--------------------- src/kzg/spec.rs | 12 ++++++----- src/lib.rs | 1 + 5 files changed, 70 insertions(+), 30 deletions(-) create mode 100644 src/bytes.rs diff --git a/Cargo.toml b/Cargo.toml index 14ce3bf..e7106a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,9 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] -alloy-primitives = { version = "0.4.2", features = ["std", "serde"] } blst = "0.3.11" criterion = "0.5.1" +hex = "0.4.3" rand = { version = "0.8.5", optional = true } serde = { version = "1.0.189", features = ["derive"] } serde_json = "1.0.107" diff --git a/src/bytes.rs b/src/bytes.rs new file mode 100644 index 0000000..24ca55d --- /dev/null +++ b/src/bytes.rs @@ -0,0 +1,52 @@ +use serde::{de::Visitor, Deserialize}; + +#[derive(Clone, Debug)] +pub struct Bytes(Vec); + +impl AsRef<[u8]> for Bytes { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +pub struct BytesVisitor; + +impl<'de> Visitor<'de> for BytesVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + formatter.write_str( + "a variable-length byte array represented by a raw byte array or a hex-encoded string", + ) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + let v = v.strip_prefix("0x").unwrap_or(v); + let v = hex::decode(v).map_err(E::custom)?; + Ok(v) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + let v = hex::decode(v).map_err(E::custom)?; + Ok(v) + } +} + +impl<'de> Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + if deserializer.is_human_readable() { + deserializer.deserialize_str(BytesVisitor).map(Bytes) + } else { + deserializer.deserialize_bytes(BytesVisitor).map(Bytes) + } + } +} diff --git a/src/kzg/setup.rs b/src/kzg/setup.rs index 2e97e99..2124ff3 100644 --- a/src/kzg/setup.rs +++ b/src/kzg/setup.rs @@ -7,23 +7,23 @@ use std::{ use super::{Bytes32, Bytes48, Commitment, Error, Polynomial, Proof}; use crate::{ blob::{Blob, Error as BlobError}, - bls::{self, Decompress, ECGroupError, Error as BlsError, Fr, P1, P2}, + bls::{self, Decompress, Error as BlsError, Fr, P1, P2}, + bytes::Bytes, math, }; -use alloy_primitives::{hex, Bytes, FixedBytes}; +use serde::Deserialize; #[derive(Debug)] pub enum LoadSetupError { Bls(BlsError), Io(io::Error), - Hex(hex::FromHexError), Serde(serde_json::Error), InvalidLenG1Lagrange, InvalidLenG2Monomial, } -#[derive(serde::Deserialize, serde::Serialize)] +#[derive(Deserialize)] struct SetupUnchecked { g1_lagrange: Vec, g2_monomial: Vec, @@ -52,13 +52,6 @@ impl Setup { let mut g1_lagrange: Box<[P1; G1]> = Box::new([P1::default(); G1]); for (i, point) in setup.g1_lagrange.iter().enumerate() { - if point.len() != 48 { - return Err(LoadSetupError::Bls(BlsError::from( - ECGroupError::InvalidEncoding, - ))); - } - // TODO: skip unnecessary allocation - let point = FixedBytes::<48>::from_slice(point); let point = P1::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?; g1_lagrange[i] = point; @@ -67,13 +60,6 @@ impl Setup { let mut g2_monomial: Box<[P2; G2]> = Box::new([P2::default(); G2]); for (i, point) in setup.g2_monomial.iter().enumerate() { - if point.len() != 96 { - return Err(LoadSetupError::Bls(BlsError::from( - ECGroupError::InvalidEncoding, - ))); - } - // TODO: skip unnecessary allocation - let point = FixedBytes::<96>::from_slice(point); let point = P2::decompress(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?; g2_monomial[i] = point; @@ -287,12 +273,11 @@ impl Setup { mod tests { use super::*; - use crate::{ - bls::Compress, - kzg::spec::{ - BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof, - VerifyBlobProofBatch, VerifyProof, - }, + use crate::bls::Compress; + + use crate::kzg::spec::{ + BlobToCommitment, ComputeBlobProof, ComputeProof, VerifyBlobProof, VerifyBlobProofBatch, + VerifyProof, }; use std::{ diff --git a/src/kzg/spec.rs b/src/kzg/spec.rs index c6473f5..54557b7 100644 --- a/src/kzg/spec.rs +++ b/src/kzg/spec.rs @@ -1,18 +1,20 @@ -use alloy_primitives::{Bytes, FixedBytes}; use serde::Deserialize; -use crate::bls::{Fr, P1}; +use crate::{ + bls::{Fr, P1}, + bytes::Bytes, +}; use super::{Bytes32, Bytes48}; fn bytes32_from_bytes(bytes: &Bytes) -> Option { - let bytes = FixedBytes::<{ Fr::BYTES }>::try_from(bytes.as_ref()).ok(); + let bytes: Option<[u8; Fr::BYTES]> = TryFrom::try_from(bytes.as_ref()).ok(); bytes.map(Into::::into) } fn bytes48_from_bytes(bytes: &Bytes) -> Option { - let bytes = FixedBytes::<{ P1::BYTES }>::try_from(bytes.as_ref()).ok()?; - Some(bytes.into()) + let bytes: Option<[u8; P1::BYTES]> = TryFrom::try_from(bytes.as_ref()).ok(); + bytes.map(Into::::into) } #[derive(Deserialize)] diff --git a/src/lib.rs b/src/lib.rs index 611fc29..4d9b1c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod bls; +mod bytes; mod math; pub use bls::{Compress, Decompress};