Skip to content

Commit d03b31f

Browse files
authored
Merge dense MLEs & update documentation (#763)
* add a method to merge polys * a slice of references should suffice for merge * implement `AsRef` for DenseMLE * change `merge` to take `AsRef<Self>` * Add tests for merging unequal polys * Update doc examples for `evaluate` * Update doc example for `from_evaluations_vec` * Update doc for `fix_variables` The resulting polynomial is in 1 variable only, no need to index it * rename internal variable and remove redundant comment * use the extracted variable for next pow of two * add changelog entry * the argument to `merge` is now an iterator instead of slice * Apply suggestions from code review * Update poly/src/evaluations/multivariate/multilinear/dense.rs
1 parent 9ce37d5 commit d03b31f

File tree

2 files changed

+151
-6
lines changed

2 files changed

+151
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
- [\#691](https://github.com/arkworks-rs/algebra/pull/691) (`ark-poly`) Implement `Polynomial` for `SparseMultilinearExtension` and `DenseMultilinearExtension`.
3232
- [\#693](https://github.com/arkworks-rs/algebra/pull/693) (`ark-serialize`) Add `serialize_to_vec!` convenience macro.
3333
- [\#713](https://github.com/arkworks-rs/algebra/pull/713) (`ark-ff`) Add support for bitwise operations AND, OR, and XOR between `BigInteger`.
34+
- [\#763](https://github.com/arkworks-rs/algebra/pull/763) (`ark-poly`) Add `concat` to concatenate evaluation tables of `DenseMultilinearPolynomial`s.
3435
- [\#811](https://github.com/arkworks-rs/algebra/pull/811) (`ark-serialize`) Implement `Valid` & `CanonicalDeserialize` for `Rc`.
3536

3637
### Improvements

poly/src/evaluations/multivariate/multilinear/dense.rs

Lines changed: 150 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
99
use ark_std::{
1010
fmt,
1111
fmt::Formatter,
12+
iter::IntoIterator,
13+
log2,
1214
ops::{Add, AddAssign, Index, Neg, Sub, SubAssign},
1315
rand::Rng,
1416
slice::{Iter, IterMut},
@@ -36,7 +38,22 @@ impl<F: Field> DenseMultilinearExtension<F> {
3638

3739
/// Construct a new polynomial from a list of evaluations where the index
3840
/// represents a point in {0,1}^`num_vars` in little endian form. For
39-
/// example, `0b1011` represents `P(1,1,0,1)`
41+
/// example, `0b1011` represents `P(1,1,0,1)`.
42+
///
43+
/// # Example
44+
/// ```
45+
/// use ark_test_curves::bls12_381::Fr;
46+
/// use ark_poly::{MultilinearExtension, Polynomial, DenseMultilinearExtension};
47+
///
48+
/// // Construct a 2-variate MLE, which takes value 1 at (x_0, x_1) = (0, 1)
49+
/// // (i.e. 0b01, or index 2 in little endian)
50+
/// // f1(x_0, x_1) = x_1*(1-x_0)
51+
/// let mle = DenseMultilinearExtension::from_evaluations_vec(
52+
/// 2, vec![0, 0, 1, 0].iter().map(|x| Fr::from(*x as u64)).collect()
53+
/// );
54+
/// let eval = mle.evaluate(&vec![Fr::from(-2), Fr::from(17)]); // point = (x_0, x_1)
55+
/// assert_eq!(eval, Fr::from(51));
56+
/// ```
4057
pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec<F>) -> Self {
4158
// assert that the number of variables matches the size of evaluations
4259
assert_eq!(
@@ -82,6 +99,64 @@ impl<F: Field> DenseMultilinearExtension<F> {
8299
pub fn iter_mut(&mut self) -> IterMut<'_, F> {
83100
self.evaluations.iter_mut()
84101
}
102+
103+
/// Concatenate the evaluation tables of multiple polynomials.
104+
/// If the combined table size is not a power of two, pad the table with zeros.
105+
///
106+
/// # Example
107+
/// ```
108+
/// use ark_test_curves::bls12_381::Fr;
109+
/// use ark_poly::{MultilinearExtension, Polynomial, DenseMultilinearExtension};
110+
/// use ark_ff::One;
111+
///
112+
/// // Construct a 2-variate multilinear polynomial f1
113+
/// // f1(x_0, x_1) = 2*(1-x_1)*(1-x_0) + 3*(1-x_1)*x_0 + 2*x_1*(1-x_0) + 6*x_1*x_0
114+
/// let mle_1 = DenseMultilinearExtension::from_evaluations_vec(
115+
/// 2, vec![2, 3, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect()
116+
/// );
117+
/// // Construct another 2-variate MLE f2
118+
/// // f2(x_0, x_1) = 1*x_1*x_0
119+
/// let mle_2 = DenseMultilinearExtension::from_evaluations_vec(
120+
/// 2, vec![0, 0, 0, 1].iter().map(|x| Fr::from(*x as u64)).collect()
121+
/// );
122+
/// let mle = DenseMultilinearExtension::concat(&[&mle_1, &mle_2]);
123+
/// // The resulting polynomial is 3-variate:
124+
/// // f3(x_0, x_1, x_2) = (1 - x_2)*f1(x_0, x_1) + x_2*f2(x_0, x_1)
125+
/// // Evaluate it at a random point (1, 17, 3)
126+
/// let point = vec![Fr::one(), Fr::from(17), Fr::from(3)];
127+
/// let eval_1 = mle_1.evaluate(&point[..2].to_vec());
128+
/// let eval_2 = mle_2.evaluate(&point[..2].to_vec());
129+
/// let eval_combined = mle.evaluate(&point);
130+
///
131+
/// assert_eq!(eval_combined, (Fr::one() - point[2]) * eval_1 + point[2] * eval_2);
132+
pub fn concat(polys: impl IntoIterator<Item = impl AsRef<Self>> + Clone) -> Self {
133+
// for efficient allocation into the concatenated vector, we need to know the total length
134+
// in advance, so we actually need to iterate twice. Cloning the iterator is cheap.
135+
let polys_iter_cloned = polys.clone().into_iter();
136+
137+
let total_len: usize = polys
138+
.into_iter()
139+
.map(|poly| poly.as_ref().evaluations.len())
140+
.sum();
141+
142+
let next_pow_of_two = total_len.next_power_of_two();
143+
let num_vars = log2(next_pow_of_two);
144+
let mut evaluations: Vec<F> = Vec::with_capacity(next_pow_of_two);
145+
146+
for poly in polys_iter_cloned {
147+
evaluations.extend_from_slice(&poly.as_ref().evaluations.as_slice());
148+
}
149+
150+
evaluations.resize(next_pow_of_two, F::zero());
151+
152+
Self::from_evaluations_slice(num_vars as usize, &evaluations)
153+
}
154+
}
155+
156+
impl<F: Field> AsRef<DenseMultilinearExtension<F>> for DenseMultilinearExtension<F> {
157+
fn as_ref(&self) -> &DenseMultilinearExtension<F> {
158+
self
159+
}
85160
}
86161

87162
impl<F: Field> MultilinearExtension<F> for DenseMultilinearExtension<F> {
@@ -118,8 +193,8 @@ impl<F: Field> MultilinearExtension<F> for DenseMultilinearExtension<F> {
118193
/// 2, vec![0, 1, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect()
119194
/// );
120195
///
121-
/// // Bind the first variable of the MLE to the value 5, resulting in
122-
/// // the new polynomial 5 + 17 * x_1
196+
/// // Bind the first variable of the MLE, x_0, to the value 5, resulting in
197+
/// // a new polynomial in one variable: 5 + 17 * x
123198
/// let bound = mle.fix_variables(&[Fr::from(5)]);
124199
///
125200
/// assert_eq!(bound.to_evaluations(), vec![Fr::from(5), Fr::from(22)]);
@@ -298,14 +373,15 @@ impl<F: Field> Polynomial<F> for DenseMultilinearExtension<F> {
298373
/// # use ark_poly::{MultilinearExtension, DenseMultilinearExtension, Polynomial};
299374
/// # use ark_ff::One;
300375
///
301-
/// // The two-variate polynomial x_0 + 3 * x_0 * x_1 + 2 evaluates to [2, 3, 2, 6]
302-
/// // in the two-dimensional hypercube with points [00, 10, 01, 11]
376+
/// // The two-variate polynomial p = x_0 + 3 * x_0 * x_1 + 2 evaluates to [2, 3, 2, 6]
377+
/// // in the two-dimensional hypercube with points [00, 10, 01, 11]:
378+
/// // p(x_0, x_1) = 2*(1-x_1)*(1-x_0) + 3*(1-x_1)*x_0 + 2*x_1*(1-x_0) + 6*x_1*x_0
303379
/// let mle = DenseMultilinearExtension::from_evaluations_vec(
304380
/// 2, vec![2, 3, 2, 6].iter().map(|x| Fr::from(*x as u64)).collect()
305381
/// );
306382
///
307383
/// // By the uniqueness of MLEs, `mle` is precisely the above polynomial, which
308-
/// // takes the value 54 at the point (1, 17)
384+
/// // takes the value 54 at the point (x_0, x_1) = (1, 17)
309385
/// let eval = mle.evaluate(&[Fr::one(), Fr::from(17)].into());
310386
/// assert_eq!(eval, Fr::from(54));
311387
/// ```
@@ -441,4 +517,72 @@ mod tests {
441517
}
442518
}
443519
}
520+
521+
#[test]
522+
fn concat_two_equal_polys() {
523+
let mut rng = test_rng();
524+
let degree = 10;
525+
526+
let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
527+
let poly_r = DenseMultilinearExtension::rand(degree, &mut rng);
528+
529+
let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);
530+
for _ in 0..10 {
531+
let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();
532+
533+
let expected = (Fr::ONE - point[10]) * poly_l.evaluate(&point[..10].to_vec())
534+
+ point[10] * poly_r.evaluate(&point[..10].to_vec());
535+
assert_eq!(expected, merged.evaluate(&point));
536+
}
537+
}
538+
539+
#[test]
540+
fn concat_unequal_polys() {
541+
let mut rng = test_rng();
542+
let degree = 10;
543+
let poly_l = DenseMultilinearExtension::rand(degree, &mut rng);
544+
// smaller poly
545+
let poly_r = DenseMultilinearExtension::rand(degree - 1, &mut rng);
546+
547+
let merged = DenseMultilinearExtension::concat(&[&poly_l, &poly_r]);
548+
549+
for _ in 0..10 {
550+
let point: Vec<_> = (0..(degree + 1)).map(|_| Fr::rand(&mut rng)).collect();
551+
552+
// merged poly is (1-x_10)*poly_l + x_10*((1-x_9)*poly_r1 + x_9*poly_r2).
553+
// where poly_r1 is poly_r, and poly_r2 is all zero, since we are padding.
554+
let expected = (Fr::ONE - point[10]) * poly_l.evaluate(&point[..10].to_vec())
555+
+ point[10] * ((Fr::ONE - point[9]) * poly_r.evaluate(&point[..9].to_vec()));
556+
assert_eq!(expected, merged.evaluate(&point));
557+
}
558+
}
559+
560+
#[test]
561+
fn concat_two_iterators() {
562+
let mut rng = test_rng();
563+
let degree = 10;
564+
565+
// rather than merging two polynomials, we merge two iterators of polynomials
566+
let polys_l: Vec<_> = (0..2)
567+
.map(|_| DenseMultilinearExtension::rand(degree - 2, &mut test_rng()))
568+
.collect();
569+
let polys_r: Vec<_> = (0..2)
570+
.map(|_| DenseMultilinearExtension::rand(degree - 2, &mut test_rng()))
571+
.collect();
572+
573+
let merged = DenseMultilinearExtension::<Fr>::concat(polys_l.iter().chain(polys_r.iter()));
574+
575+
for _ in 0..10 {
576+
let point: Vec<_> = (0..(degree)).map(|_| Fr::rand(&mut rng)).collect();
577+
578+
let expected = (Fr::ONE - point[9])
579+
* ((Fr::ONE - point[8]) * polys_l[0].evaluate(&point[..8].to_vec())
580+
+ point[8] * polys_l[1].evaluate(&point[..8].to_vec()))
581+
+ point[9]
582+
* ((Fr::ONE - point[8]) * polys_r[0].evaluate(&point[..8].to_vec())
583+
+ point[8] * polys_r[1].evaluate(&point[..8].to_vec()));
584+
585+
assert_eq!(expected, merged.evaluate(&point));
586+
}
587+
}
444588
}

0 commit comments

Comments
 (0)