Skip to content

Commit

Permalink
online-phase: algebra: scalar: Optimize (de)serialization
Browse files Browse the repository at this point in the history
This has large effects throughout any stack built on top of the `Scalar`
type, so we use the `ark_serialize` `CanonicalSerialize` and
`CanonicalDeserialize` traits.
  • Loading branch information
joeykraut committed Aug 19, 2024
1 parent 60a748a commit c2a363a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 4 deletions.
5 changes: 5 additions & 0 deletions online-phase/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ path = "integration/main.rs"
harness = false
required-features = ["test_helpers"]

[[bench]]
name = "scalar_serialization"
harness = false
required-features = ["benchmarks", "test_helpers"]

[[bench]]
name = "batch_ops"
harness = false
Expand Down
61 changes: 61 additions & 0 deletions online-phase/benches/scalar_serialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::time::{Duration, Instant};

use ark_mpc::{algebra::Scalar, test_helpers::TestCurve};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use rand::thread_rng;

/// Benchmark the serialization of scalars
fn bench_scalar_serialization(c: &mut Criterion) {
let mut rng = thread_rng();
let mut group = c.benchmark_group("scalar_serialization");
group.throughput(Throughput::Elements(1));

group.bench_function("scalar_serialization", |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::from_secs(0);
for _ in 0..n_iters {
let scalar = Scalar::<TestCurve>::random(&mut rng);

let start = Instant::now();
let bytes = serde_json::to_value(scalar).unwrap();
total_time += start.elapsed();

black_box(bytes);
}
total_time
})
});
}

/// Benchmark the deserialization of scalars
fn bench_scalar_deserialization(c: &mut Criterion) {
let mut rng = thread_rng();
let mut group = c.benchmark_group("scalar_deserialization");
group.throughput(Throughput::Elements(1));

group.bench_function("scalar_deserialization", |b| {
b.iter_custom(|n_iters| {
let mut total_time = Duration::from_secs(0);
for _ in 0..n_iters {
let scalar = Scalar::<TestCurve>::random(&mut rng);
let serialized = serde_json::to_value(scalar).unwrap();

// Time deserialization only
let start = Instant::now();
let deserialized: Scalar<TestCurve> = serde_json::from_value(serialized).unwrap();
total_time += start.elapsed();

black_box(deserialized);
}

total_time
})
});
}

criterion_group! {
name = scalar_ops;
config = Criterion::default();
targets = bench_scalar_serialization, bench_scalar_deserialization
}
criterion_main!(scalar_ops);
22 changes: 18 additions & 4 deletions online-phase/src/algebra/scalar/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::{
use ark_ec::CurveGroup;
use ark_ff::{batch_inversion, FftField, Field, One, PrimeField, Zero};
use ark_poly::EvaluationDomain;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::UniformRand;
use itertools::Itertools;
use num_bigint::BigUint;
Expand Down Expand Up @@ -168,16 +169,18 @@ impl<C: CurveGroup> Display for Scalar<C> {

impl<C: CurveGroup> Serialize for Scalar<C> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let bytes = self.to_bytes_be();
bytes.serialize(serializer)
let mut bytes = Vec::with_capacity(n_bytes_field::<C::ScalarField>());
self.0.serialize_uncompressed(&mut bytes).map_err(serde::ser::Error::custom)?;
serializer.serialize_bytes(&bytes)
}
}

impl<'de, C: CurveGroup> Deserialize<'de> for Scalar<C> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes = <Vec<u8>>::deserialize(deserializer)?;
let scalar = Scalar::from_be_bytes_mod_order(&bytes);
Ok(scalar)
let inner = C::ScalarField::deserialize_uncompressed(bytes.as_slice())
.map_err(serde::de::Error::custom)?;
Ok(Scalar(inner))
}
}

Expand Down Expand Up @@ -720,6 +723,17 @@ mod test {
use itertools::Itertools;
use rand::{thread_rng, Rng, RngCore};

/// Tests serialization and deserialization of scalars
#[test]
fn test_scalar_serialization() {
let mut rng = thread_rng();
let scalar = Scalar::<TestCurve>::random(&mut rng);

let bytes = serde_json::to_vec(&scalar).unwrap();
let deserialized: Scalar<TestCurve> = serde_json::from_slice(&bytes).unwrap();
assert_eq!(scalar, deserialized);
}

/// Tests addition of raw scalars in a circuit
#[tokio::test]
async fn test_scalar_add() {
Expand Down

0 comments on commit c2a363a

Please sign in to comment.