Skip to content

Commit

Permalink
refactor: add serde feature with conditional compilation (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobkaufmann authored Mar 22, 2024
1 parent 7ba9bdb commit fcd2867
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 45 deletions.
11 changes: 6 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ edition = "2021"

[dependencies]
blst = "0.3.11"
hex = "0.4.3"
hex = { version = "0.4.3", optional = true }
rand = { version = "0.8.5", optional = true }
serde = { version = "1.0.189", features = ["derive"] }
serde_json = "1.0.107"
serde_yaml = "0.9.25"
serde = { version = "1.0.189", features = ["derive"], optional = true }
serde_json = { version = "1.0.107", optional = true }
serde_yaml = { version = "0.9.25", optional = true }

[dev-dependencies]
criterion = "0.5.1"
Expand All @@ -18,8 +18,9 @@ rand = "0.8.5"
[features]
default = []
rand = ["dep:rand"]
serde = ["dep:hex", "dep:serde", "dep:serde_json", "dep:serde_yaml"]

[[bench]]
name = "kzg"
harness = false
required-features = ["rand"]
required-features = ["rand", "serde"]
2 changes: 1 addition & 1 deletion benches/kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rand::thread_rng;

pub fn benchmark(c: &mut Criterion) {
let path = format!("{}/trusted_setup_4096.json", env!("CARGO_MANIFEST_DIR"));
let kzg = Setup::<4096, 65>::load(path).unwrap();
let kzg = Setup::<4096, 65>::load_json(path).unwrap();

let batch_sizes = [1usize, 2, 4, 8, 16, 32, 64, 128];
let max_batch_size = *batch_sizes.last().unwrap();
Expand Down
76 changes: 42 additions & 34 deletions src/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use serde::{de::Visitor, Deserialize};

#[derive(Clone, Debug)]
pub struct Bytes(Vec<u8>);

Expand All @@ -9,44 +7,54 @@ impl AsRef<[u8]> for Bytes {
}
}

pub struct BytesVisitor;
#[cfg(any(test, feature = "serde"))]
pub mod serde {
use super::*;

impl<'de> Visitor<'de> for BytesVisitor {
type Value = Vec<u8>;
use ::serde::{
de::{Error, Visitor},
Deserialize, Deserializer,
};

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
formatter.write_str(
"a variable-length byte array represented by a raw byte array or a hex-encoded string",
)
}
struct BytesVisitor;

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let v = v.strip_prefix("0x").unwrap_or(v);
let v = hex::decode(v).map_err(E::custom)?;
Ok(v)
}
impl<'de> Visitor<'de> for BytesVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
formatter.write_str(
"a variable-length byte array represented by a raw byte array or a hex-encoded string",
)
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let v = hex::decode(v).map_err(E::custom)?;
Ok(v)
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
let v = v.strip_prefix("0x").unwrap_or(v);
let v = hex::decode(v).map_err(E::custom)?;
Ok(v)
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
let v = hex::decode(v).map_err(E::custom)?;
Ok(v)
}
}
}

impl<'de> Deserialize<'de> for Bytes {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(BytesVisitor).map(Bytes)
} else {
deserializer.deserialize_bytes(BytesVisitor).map(Bytes)
impl<'de> Deserialize<'de> for Bytes {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(BytesVisitor).map(Bytes)
} else {
deserializer.deserialize_bytes(BytesVisitor).map(Bytes)
}
}
}
}
16 changes: 11 additions & 5 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[cfg(feature = "serde")]
use std::{
fs::File,
io::{self, BufReader},
Expand All @@ -8,12 +9,15 @@ use super::{Bytes32, Bytes48, Commitment, Error, Polynomial, Proof};
use crate::{
blob::{Blob, Error as BlobError},
bls::{self, Decompress, Error as BlsError, Fr, P1, P2},
bytes::Bytes,
math,
};

#[cfg(feature = "serde")]
use crate::{bytes::Bytes, math};

#[cfg(feature = "serde")]
use serde::Deserialize;

#[cfg(feature = "serde")]
#[derive(Debug)]
pub enum LoadSetupError {
Bls(BlsError),
Expand All @@ -23,6 +27,7 @@ pub enum LoadSetupError {
InvalidLenG2Monomial,
}

#[cfg(feature = "serde")]
#[derive(Deserialize)]
struct SetupUnchecked {
g1_lagrange: Vec<Bytes>,
Expand All @@ -37,7 +42,8 @@ pub struct Setup<const G1: usize, const G2: usize> {
}

impl<const G1: usize, const G2: usize> Setup<G1, G2> {
pub fn load(path: impl AsRef<Path>) -> Result<Self, LoadSetupError> {
#[cfg(feature = "serde")]
pub fn load_json(path: impl AsRef<Path>) -> Result<Self, LoadSetupError> {
let file = File::open(path).map_err(LoadSetupError::Io)?;
let reader = BufReader::new(file);
let setup: SetupUnchecked =
Expand Down Expand Up @@ -269,7 +275,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
}
}

#[cfg(test)]
#[cfg(all(test, feature = "serde"))]
mod tests {
use super::*;

Expand All @@ -293,7 +299,7 @@ mod tests {
fn setup() -> Setup<FIELD_ELEMENTS_PER_BLOB, SETUP_G2_LEN> {
let path = format!("{}/trusted_setup_4096.json", env!("CARGO_MANIFEST_DIR"));
let path = PathBuf::from(path);
Setup::load(path).unwrap()
Setup::load_json(path).unwrap()
}

fn consensus_spec_test_files(dir: impl AsRef<str>) -> impl Iterator<Item = File> {
Expand Down

0 comments on commit fcd2867

Please sign in to comment.