Skip to content

Commit

Permalink
Add Sum for nonnative
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush committed Aug 11, 2021
1 parent c28febd commit a2d2d47
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/fields/fp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl<F: PrimeField> AllocatedFp<F> {
/// Add many allocated Fp elements together.
///
/// This does not create any constraints and only creates one linear combination.
pub fn addmany<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self {
pub fn add_many<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self {
let mut cs = ConstraintSystemRef::None;
let mut has_value = true;
let mut value = F::zero();
Expand Down Expand Up @@ -1062,7 +1062,7 @@ impl<F: PrimeField> AllocVar<F, F> for FpVar<F> {
impl<'a, F: PrimeField> Sum<&'a FpVar<F>> for FpVar<F> {
fn sum<I: Iterator<Item = &'a FpVar<F>>>(iter: I) -> FpVar<F> {
let mut sum_constants = F::zero();
let sum_variables = FpVar::Var(AllocatedFp::<F>::addmany(iter.filter_map(|x| match x {
let sum_variables = FpVar::Var(AllocatedFp::<F>::add_many(iter.filter_map(|x| match x {
FpVar::Constant(c) => {
sum_constants += c;
None
Expand All @@ -1087,7 +1087,7 @@ impl<F: PrimeField> Sum<FpVar<F>> for FpVar<F> {
FpVar::Var(v) => Some(v),
})
.collect::<Vec<_>>();
let sum_variables = FpVar::Var(AllocatedFp::<F>::addmany(vars.iter()));
let sum_variables = FpVar::Var(AllocatedFp::<F>::add_many(vars.iter()));

let sum = sum_variables + sum_constants;
sum
Expand Down
42 changes: 42 additions & 0 deletions src/fields/nonnative/allocated_field_var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,48 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
Ok(res)
}

/// Add many allocated elements together.
///
/// This does not create any constraints and only creates #limbs linear combinations.
///
/// If there are 0 items in the iterator, then this returns `Ok(None)`.
pub fn add_many<'a, I: Iterator<Item = &'a Self>>(
iter: I,
) -> Result<Option<Self>, SynthesisError> {
let mut limbs_iter = Vec::new();
let cs;
let mut num_of_additions_over_normal_form = BaseField::zero();
let is_in_the_normal_form = false;
if let Some(first) = iter.next() {
cs = first.cs();
for limb in &first.limbs {
limbs_iter.push(vec![limb]);
}
for elem in iter {
for (cur_limb, limbs) in elem.limbs.iter().zip(limbs_iter) {
limbs.push(cur_limb);
}
num_of_additions_over_normal_form += BaseField::one();
}
let limbs = limbs_iter
.into_iter()
.map(|limbs| limbs.into_iter().sum::<FpVar<_>>())
.collect::<Vec<_>>();

let result = Self {
cs,
limbs,
num_of_additions_over_normal_form,
is_in_the_normal_form,
target_phantom: PhantomData,
};
Reducer::<TargetField, BaseField>::post_add_reduce(&mut result)?;
Ok(Some(result))
} else {
Ok(None)
}
}

/// Subtract a nonnative field element, without the final reduction step
#[tracing::instrument(target = "r1cs")]
pub fn sub_without_reduce(&self, other: &Self) -> R1CSResult<Self> {
Expand Down
109 changes: 88 additions & 21 deletions src/fields/nonnative/field_var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use ark_ff::{to_bytes, FpParameters};
use ark_relations::r1cs::Result as R1CSResult;
use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError};
use ark_std::hash::{Hash, Hasher};
use ark_std::{borrow::Borrow, vec::Vec};
use ark_std::{borrow::Borrow, iter::Sum, vec::Vec};

/// A gadget for representing non-native (`TargetField`) field elements over the constraint field (`BaseField`).
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -116,11 +116,30 @@ impl<TargetField: PrimeField, BaseField: PrimeField> FieldVar<TargetField, BaseF
}

#[tracing::instrument(target = "r1cs")]
fn negate(&self) -> R1CSResult<Self> {
match self {
Self::Constant(c) => Ok(Self::Constant(-*c)),
Self::Var(v) => Ok(Self::Var(v.negate()?)),
}
fn negate_in_place(&mut self) -> R1CSResult<&mut Self> {
*self = match self {
Self::Constant(c) => Self::Constant(-*c),
Self::Var(v) => Self::Var(v.negate()?),
};
Ok(self)
}

#[tracing::instrument(target = "r1cs")]
fn double_in_place(&mut self) -> R1CSResult<&mut Self> {
*self = match self {
Self::Constant(c) => Self::Constant(c.double()),
Self::Var(v) => Self::Var(v.add(&*v)?),
};
Ok(self)
}

#[tracing::instrument(target = "r1cs")]
fn square_in_place(&mut self) -> R1CSResult<&mut Self> {
*self = match self {
Self::Constant(c) => Self::Constant(c.square()),
Self::Var(v) => Self::Var(v.mul(&*v)?),
};
Ok(self)
}

#[tracing::instrument(target = "r1cs")]
Expand Down Expand Up @@ -154,15 +173,15 @@ impl_bounded_ops!(
add,
AddAssign,
add_assign,
|this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
|this: &mut NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
use NonNativeFieldVar::*;
match (this, other) {
*this = match (&*this, other) {
(Constant(c1), Constant(c2)) => Constant(*c1 + c2),
(Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()),
(Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()),
}
};
},
|this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| { this + &NonNativeFieldVar::Constant(other) },
|this: &mut NonNativeFieldVar<TargetField, BaseField>, other: TargetField| { *this = &*this + &NonNativeFieldVar::Constant(other) },
(TargetField: PrimeField, BaseField: PrimeField),
);

Expand All @@ -173,17 +192,17 @@ impl_bounded_ops!(
sub,
SubAssign,
sub_assign,
|this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
|this: &mut NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
use NonNativeFieldVar::*;
match (this, other) {
*this = match (&*this, other) {
(Constant(c1), Constant(c2)) => Constant(*c1 - c2),
(Var(v), Constant(c)) => Var(v.sub_constant(c).unwrap()),
(Constant(c), Var(v)) => Var(v.sub_constant(c).unwrap().negate().unwrap()),
(Var(v1), Var(v2)) => Var(v1.sub(v2).unwrap()),
}
};
},
|this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
this - &NonNativeFieldVar::Constant(other)
|this: &mut NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
*this = &*this - &NonNativeFieldVar::Constant(other)
},
(TargetField: PrimeField, BaseField: PrimeField),
);
Expand All @@ -195,20 +214,20 @@ impl_bounded_ops!(
mul,
MulAssign,
mul_assign,
|this: &'a NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
|this: &mut NonNativeFieldVar<TargetField, BaseField>, other: &'a NonNativeFieldVar<TargetField, BaseField>| {
use NonNativeFieldVar::*;
match (this, other) {
*this = match (&*this, other) {
(Constant(c1), Constant(c2)) => Constant(*c1 * c2),
(Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.mul_constant(c).unwrap()),
(Var(v1), Var(v2)) => Var(v1.mul(v2).unwrap()),
}
},
|this: &'a NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
if other.is_zero() {
|this: &mut NonNativeFieldVar<TargetField, BaseField>, other: TargetField| {
*this = if other.is_zero() {
NonNativeFieldVar::zero()
} else {
this * &NonNativeFieldVar::Constant(other)
}
&*this * &NonNativeFieldVar::Constant(other)
};
},
(TargetField: PrimeField, BaseField: PrimeField),
);
Expand Down Expand Up @@ -454,6 +473,54 @@ impl<TargetField: PrimeField, BaseField: PrimeField> ToConstraintFieldGadget<Bas
}
}

impl<'a, TargetField: PrimeField, BaseField: PrimeField> Sum<&'a Self>
for NonNativeFieldVar<TargetField, BaseField>
{
fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
let mut sum_constants = TargetField::zero();
let vars = iter
.filter_map(|x| match x {
Self::Constant(c) => {
sum_constants += c;
None
}
Self::Var(v) => Some(v),
})
.collect::<Vec<_>>();
let sum_variables = AllocatedNonNativeFieldVar::add_many(vars.into_iter())
.unwrap()
.map(Self::Var)
.unwrap_or(Self::zero());

let sum = sum_variables + sum_constants;
sum
}
}

impl<TargetField: PrimeField, BaseField: PrimeField> Sum<Self>
for NonNativeFieldVar<TargetField, BaseField>
{
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
let mut sum_constants = TargetField::zero();
let vars = iter
.filter_map(|x| match x {
Self::Constant(c) => {
sum_constants += c;
None
}
Self::Var(v) => Some(v),
})
.collect::<Vec<_>>();
let sum_variables = AllocatedNonNativeFieldVar::add_many(vars.iter())
.unwrap()
.map(Self::Var)
.unwrap_or(Self::zero());

let sum = sum_variables + sum_constants;
sum
}
}

impl<TargetField: PrimeField, BaseField: PrimeField> NonNativeFieldVar<TargetField, BaseField> {
/// The `mul_without_reduce` for `NonNativeFieldVar`
#[tracing::instrument(target = "r1cs")]
Expand Down
8 changes: 4 additions & 4 deletions src/fields/nonnative/mul_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ impl_bounded_ops!(
add,
AddAssign,
add_assign,
|this: &'a NonNativeFieldMulResultVar<TargetField, BaseField>, other: &'a NonNativeFieldMulResultVar<TargetField, BaseField>| {
|this: &mut NonNativeFieldMulResultVar<TargetField, BaseField>, other: &'a NonNativeFieldMulResultVar<TargetField, BaseField>| {
use NonNativeFieldMulResultVar::*;
match (this, other) {
*this = match (&*this, other) {
(Constant(c1), Constant(c2)) => Constant(*c1 + c2),
(Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()),
(Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()),
}
};
},
|this: &'a NonNativeFieldMulResultVar<TargetField, BaseField>, other: TargetField| { this + &NonNativeFieldMulResultVar::Constant(other) },
|this: &mut NonNativeFieldMulResultVar<TargetField, BaseField>, other: TargetField| { *this = &*this + &NonNativeFieldMulResultVar::Constant(other) },
(TargetField: PrimeField, BaseField: PrimeField),
);

0 comments on commit a2d2d47

Please sign in to comment.