From cf43e061439b0749e2bf5d5b4125c96aa2bb2b6c Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Mon, 1 Jul 2024 13:42:22 -0500 Subject: [PATCH] Partial precomputation --- .../serial/scalar_mul/precomputed_straus.rs | 4 +- .../vector/scalar_mul/precomputed_straus.rs | 4 +- curve25519-dalek/src/ristretto.rs | 144 ++++++++++++++++++ curve25519-dalek/src/traits.rs | 22 ++- 4 files changed, 163 insertions(+), 11 deletions(-) diff --git a/curve25519-dalek/src/backend/serial/scalar_mul/precomputed_straus.rs b/curve25519-dalek/src/backend/serial/scalar_mul/precomputed_straus.rs index 711649e21..53116c628 100644 --- a/curve25519-dalek/src/backend/serial/scalar_mul/precomputed_straus.rs +++ b/curve25519-dalek/src/backend/serial/scalar_mul/precomputed_straus.rs @@ -75,7 +75,7 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { let sp = self.static_lookup_tables.len(); let dp = dynamic_lookup_tables.len(); - assert_eq!(sp, static_nafs.len()); + assert!(sp >= static_nafs.len()); assert_eq!(dp, dynamic_nafs.len()); // We could save some doublings by looking for the highest @@ -99,7 +99,7 @@ impl VartimePrecomputedMultiscalarMul for VartimePrecomputedStraus { } #[allow(clippy::needless_range_loop)] - for i in 0..sp { + for i in 0..static_nafs.len() { let t_ij = static_nafs[i][j]; match t_ij.cmp(&0) { Ordering::Greater => { diff --git a/curve25519-dalek/src/backend/vector/scalar_mul/precomputed_straus.rs b/curve25519-dalek/src/backend/vector/scalar_mul/precomputed_straus.rs index 515b4040c..1f16ab3e1 100644 --- a/curve25519-dalek/src/backend/vector/scalar_mul/precomputed_straus.rs +++ b/curve25519-dalek/src/backend/vector/scalar_mul/precomputed_straus.rs @@ -83,7 +83,7 @@ pub mod spec { let sp = self.static_lookup_tables.len(); let dp = dynamic_lookup_tables.len(); - assert_eq!(sp, static_nafs.len()); + assert!(sp >= static_nafs.len()); assert_eq!(dp, dynamic_nafs.len()); // We could save some doublings by looking for the highest @@ -107,7 +107,7 @@ pub mod spec { } #[allow(clippy::needless_range_loop)] - for i in 0..sp { + for i in 0..static_nafs.len() { let t_ij = static_nafs[i][j]; match t_ij.cmp(&0) { Ordering::Greater => { diff --git a/curve25519-dalek/src/ristretto.rs b/curve25519-dalek/src/ristretto.rs index c9d16aba3..a3e080cbf 100644 --- a/curve25519-dalek/src/ristretto.rs +++ b/curve25519-dalek/src/ristretto.rs @@ -1869,4 +1869,148 @@ mod test { assert_eq!(P.compress(), R.compress()); assert_eq!(Q.compress(), R.compress()); } + + #[test] + #[cfg(feature = "alloc")] + fn partial_precomputed_mixed_multiscalar_empty() { + let mut rng = rand::thread_rng(); + + let n_static = 16; + let n_dynamic = 8; + + let static_points = (0..n_static) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + + // Use zero scalars + let static_scalars = Vec::new(); + + let dynamic_points = (0..n_dynamic) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + + let dynamic_scalars = (0..n_dynamic) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + // Compute the linear combination using precomputed multiscalar multiplication + let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter()); + let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul( + &static_scalars, + &dynamic_scalars, + &dynamic_points, + ); + + // Compute the linear combination manually + let mut result_manual = RistrettoPoint::identity(); + for i in 0..static_scalars.len() { + result_manual += static_points[i] * static_scalars[i]; + } + for i in 0..n_dynamic { + result_manual += dynamic_points[i] * dynamic_scalars[i]; + } + + assert_eq!(result_multiscalar, result_manual); + } + + #[test] + #[cfg(feature = "alloc")] + fn partial_precomputed_mixed_multiscalar() { + let mut rng = rand::thread_rng(); + + let n_static = 16; + let n_dynamic = 8; + + let static_points = (0..n_static) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + + // Use one fewer scalars + let static_scalars = (0..n_static - 1) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + let dynamic_points = (0..n_dynamic) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + + let dynamic_scalars = (0..n_dynamic) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + // Compute the linear combination using precomputed multiscalar multiplication + let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter()); + let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul( + &static_scalars, + &dynamic_scalars, + &dynamic_points, + ); + + // Compute the linear combination manually + let mut result_manual = RistrettoPoint::identity(); + for i in 0..static_scalars.len() { + result_manual += static_points[i] * static_scalars[i]; + } + for i in 0..n_dynamic { + result_manual += dynamic_points[i] * dynamic_scalars[i]; + } + + assert_eq!(result_multiscalar, result_manual); + } + + #[test] + #[cfg(feature = "alloc")] + fn partial_precomputed_multiscalar() { + let mut rng = rand::thread_rng(); + + let n_static = 16; + + let static_points = (0..n_static) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + + // Use one fewer scalars + let static_scalars = (0..n_static - 1) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + // Compute the linear combination using precomputed multiscalar multiplication + let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter()); + let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars); + + // Compute the linear combination manually + let mut result_manual = RistrettoPoint::identity(); + for i in 0..static_scalars.len() { + result_manual += static_points[i] * static_scalars[i]; + } + + assert_eq!(result_multiscalar, result_manual); + } + + #[test] + #[cfg(feature = "alloc")] + fn partial_precomputed_multiscalar_empty() { + let mut rng = rand::thread_rng(); + + let n_static = 16; + + let static_points = (0..n_static) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + + // Use zero scalars + let static_scalars = Vec::new(); + + // Compute the linear combination using precomputed multiscalar multiplication + let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter()); + let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars); + + // Compute the linear combination manually + let mut result_manual = RistrettoPoint::identity(); + for i in 0..static_scalars.len() { + result_manual += static_points[i] * static_scalars[i]; + } + + assert_eq!(result_multiscalar, result_manual); + } } diff --git a/curve25519-dalek/src/traits.rs b/curve25519-dalek/src/traits.rs index 322787db5..ea7ca3be7 100644 --- a/curve25519-dalek/src/traits.rs +++ b/curve25519-dalek/src/traits.rs @@ -285,7 +285,7 @@ pub trait VartimeMultiscalarMul { /// to be composed into the input iterators. /// /// All methods require that the lengths of the input iterators be -/// known and matching, as if they were `ExactSizeIterator`s. (It +/// known, as if they were `ExactSizeIterator`s. (It /// does not require `ExactSizeIterator` only because that trait is /// broken). pub trait VartimePrecomputedMultiscalarMul: Sized { @@ -306,8 +306,10 @@ pub trait VartimePrecomputedMultiscalarMul: Sized { /// $$ /// where the \\(B_j\\) are the points that were supplied to `new`. /// - /// It is an error to call this function with iterators of - /// inconsistent lengths. + /// It is valid for \\(b_i\\) to have a shorter length than \\(B_j\\). + /// In this case, any "unused" points are ignored in the computation. + /// It is an error to call this function if \\(b_i\\) has a longer + /// length than \\(B_j\\). /// /// The trait bound aims for maximum flexibility: the input must /// be convertable to iterators (`I: IntoIter`), and the @@ -337,8 +339,11 @@ pub trait VartimePrecomputedMultiscalarMul: Sized { /// $$ /// where the \\(B_j\\) are the points that were supplied to `new`. /// - /// It is an error to call this function with iterators of - /// inconsistent lengths. + /// It is valid for \\(b_i\\) to have a shorter length than \\(B_j\\). + /// In this case, any "unused" points are ignored in the computation. + /// It is an error to call this function if \\(b_i\\) has a longer + /// length than \\(B_j\\), or if \\(a_i\\) and \\(A_i\\) do not have + /// the same length. /// /// The trait bound aims for maximum flexibility: the inputs must be /// convertable to iterators (`I: IntoIter`), and the iterator's items @@ -378,8 +383,11 @@ pub trait VartimePrecomputedMultiscalarMul: Sized { /// /// If any of the dynamic points were `None`, return `None`. /// - /// It is an error to call this function with iterators of - /// inconsistent lengths. + /// It is valid for \\(b_i\\) to have a shorter length than \\(B_j\\). + /// In this case, any "unused" points are ignored in the computation. + /// It is an error to call this function if \\(b_i\\) has a longer + /// length than \\(B_j\\), or if \\(a_i\\) and \\(A_i\\) do not have + /// the same length. /// /// This function is particularly useful when verifying statements /// involving compressed points. Accepting `Option` allows