Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

curve: partial precomputation #668

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus {

let sp = self.static_lookup_tables.len();
let dp = dynamic_lookup_tables.len();
assert_eq!(sp, static_nafs.len());
assert!(sp >= static_nafs.len());
assert_eq!(dp, dynamic_nafs.len());

// We could save some doublings by looking for the highest
Expand All @@ -99,7 +99,7 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus {
}

#[allow(clippy::needless_range_loop)]
for i in 0..sp {
for i in 0..static_nafs.len() {
let t_ij = static_nafs[i][j];
match t_ij.cmp(&0) {
Ordering::Greater => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub mod spec {

let sp = self.static_lookup_tables.len();
let dp = dynamic_lookup_tables.len();
assert_eq!(sp, static_nafs.len());
assert!(sp >= static_nafs.len());
assert_eq!(dp, dynamic_nafs.len());

// We could save some doublings by looking for the highest
Expand All @@ -107,7 +107,7 @@ pub mod spec {
}

#[allow(clippy::needless_range_loop)]
for i in 0..sp {
for i in 0..static_nafs.len() {
let t_ij = static_nafs[i][j];
match t_ij.cmp(&0) {
Ordering::Greater => {
Expand Down
144 changes: 144 additions & 0 deletions curve25519-dalek/src/ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1869,4 +1869,148 @@ mod test {
assert_eq!(P.compress(), R.compress());
assert_eq!(Q.compress(), R.compress());
}

#[test]
#[cfg(feature = "alloc")]
fn partial_precomputed_mixed_multiscalar_empty() {
let mut rng = rand::thread_rng();

let n_static = 16;
let n_dynamic = 8;

let static_points = (0..n_static)
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<_>>();

// Use zero scalars
let static_scalars = Vec::new();

let dynamic_points = (0..n_dynamic)
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<_>>();

let dynamic_scalars = (0..n_dynamic)
.map(|_| Scalar::random(&mut rng))
.collect::<Vec<_>>();

// Compute the linear combination using precomputed multiscalar multiplication
let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul(
&static_scalars,
&dynamic_scalars,
&dynamic_points,
);

// Compute the linear combination manually
let mut result_manual = RistrettoPoint::identity();
for i in 0..static_scalars.len() {
result_manual += static_points[i] * static_scalars[i];
}
for i in 0..n_dynamic {
result_manual += dynamic_points[i] * dynamic_scalars[i];
}

assert_eq!(result_multiscalar, result_manual);
}

#[test]
#[cfg(feature = "alloc")]
fn partial_precomputed_mixed_multiscalar() {
let mut rng = rand::thread_rng();

let n_static = 16;
let n_dynamic = 8;

let static_points = (0..n_static)
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<_>>();

// Use one fewer scalars
let static_scalars = (0..n_static - 1)
.map(|_| Scalar::random(&mut rng))
.collect::<Vec<_>>();

let dynamic_points = (0..n_dynamic)
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<_>>();

let dynamic_scalars = (0..n_dynamic)
.map(|_| Scalar::random(&mut rng))
.collect::<Vec<_>>();

// Compute the linear combination using precomputed multiscalar multiplication
let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul(
&static_scalars,
&dynamic_scalars,
&dynamic_points,
);

// Compute the linear combination manually
let mut result_manual = RistrettoPoint::identity();
for i in 0..static_scalars.len() {
result_manual += static_points[i] * static_scalars[i];
}
for i in 0..n_dynamic {
result_manual += dynamic_points[i] * dynamic_scalars[i];
}

assert_eq!(result_multiscalar, result_manual);
}

#[test]
#[cfg(feature = "alloc")]
fn partial_precomputed_multiscalar() {
let mut rng = rand::thread_rng();

let n_static = 16;

let static_points = (0..n_static)
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<_>>();

// Use one fewer scalars
let static_scalars = (0..n_static - 1)
.map(|_| Scalar::random(&mut rng))
.collect::<Vec<_>>();

// Compute the linear combination using precomputed multiscalar multiplication
let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars);

// Compute the linear combination manually
let mut result_manual = RistrettoPoint::identity();
for i in 0..static_scalars.len() {
result_manual += static_points[i] * static_scalars[i];
}

assert_eq!(result_multiscalar, result_manual);
}

#[test]
#[cfg(feature = "alloc")]
fn partial_precomputed_multiscalar_empty() {
let mut rng = rand::thread_rng();

let n_static = 16;

let static_points = (0..n_static)
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<_>>();

// Use zero scalars
let static_scalars = Vec::new();

// Compute the linear combination using precomputed multiscalar multiplication
let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars);

// Compute the linear combination manually
let mut result_manual = RistrettoPoint::identity();
for i in 0..static_scalars.len() {
result_manual += static_points[i] * static_scalars[i];
}

assert_eq!(result_multiscalar, result_manual);
}
}
22 changes: 15 additions & 7 deletions curve25519-dalek/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ pub trait VartimeMultiscalarMul {
/// to be composed into the input iterators.
///
/// All methods require that the lengths of the input iterators be
/// known and matching, as if they were `ExactSizeIterator`s. (It
/// known, as if they were `ExactSizeIterator`s. (It
/// does not require `ExactSizeIterator` only because that trait is
/// broken).
pub trait VartimePrecomputedMultiscalarMul: Sized {
Expand All @@ -306,8 +306,10 @@ pub trait VartimePrecomputedMultiscalarMul: Sized {
/// $$
/// where the \\(B_j\\) are the points that were supplied to `new`.
///
/// It is an error to call this function with iterators of
/// inconsistent lengths.
/// It is valid for \\(b_i\\) to have a shorter length than \\(B_j\\).
/// In this case, any "unused" points are ignored in the computation.
/// It is an error to call this function if \\(b_i\\) has a longer
/// length than \\(B_j\\).
///
/// The trait bound aims for maximum flexibility: the input must
/// be convertable to iterators (`I: IntoIter`), and the
Expand Down Expand Up @@ -337,8 +339,11 @@ pub trait VartimePrecomputedMultiscalarMul: Sized {
/// $$
/// where the \\(B_j\\) are the points that were supplied to `new`.
///
/// It is an error to call this function with iterators of
/// inconsistent lengths.
/// It is valid for \\(b_i\\) to have a shorter length than \\(B_j\\).
/// In this case, any "unused" points are ignored in the computation.
/// It is an error to call this function if \\(b_i\\) has a longer
/// length than \\(B_j\\), or if \\(a_i\\) and \\(A_i\\) do not have
/// the same length.
///
/// The trait bound aims for maximum flexibility: the inputs must be
/// convertable to iterators (`I: IntoIter`), and the iterator's items
Expand Down Expand Up @@ -378,8 +383,11 @@ pub trait VartimePrecomputedMultiscalarMul: Sized {
///
/// If any of the dynamic points were `None`, return `None`.
///
/// It is an error to call this function with iterators of
/// inconsistent lengths.
/// It is valid for \\(b_i\\) to have a shorter length than \\(B_j\\).
/// In this case, any "unused" points are ignored in the computation.
/// It is an error to call this function if \\(b_i\\) has a longer
/// length than \\(B_j\\), or if \\(a_i\\) and \\(A_i\\) do not have
/// the same length.
///
/// This function is particularly useful when verifying statements
/// involving compressed points. Accepting `Option<Point>` allows
Expand Down
Loading