Skip to content

Commit

Permalink
Merge pull request #2717 from o1-labs/marc/mvpoly/simplify-add
Browse files Browse the repository at this point in the history
MvPoly: reuse add by value for add by ref
  • Loading branch information
dannywillems authored Oct 16, 2024
2 parents 434ad30 + af17842 commit 1f8cf43
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 87 deletions.
68 changes: 4 additions & 64 deletions mvpoly/src/monomials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,80 +28,25 @@ impl<const N: usize, const D: usize, F: PrimeField> Add for Sparse<F, N, D> {
type Output = Self;

fn add(self, other: Self) -> Self {
let mut monomials = self.monomials.clone();
for (exponents, coeff) in other.monomials {
monomials
.entry(exponents)
.and_modify(|c| *c += coeff)
.or_insert(coeff);
}
// Remove monomials with zero coefficients
let monomials: HashMap<[usize; N], F> = monomials
.into_iter()
.filter(|(_, coeff)| !coeff.is_zero())
.collect();
// Handle the case where the result is zero because we want a unique
// representation
if monomials.is_empty() {
Self::zero()
} else {
Sparse::<F, N, D> { monomials }
}
&self + &other
}
}

impl<const N: usize, const D: usize, F: PrimeField> Add<&Sparse<F, N, D>> for Sparse<F, N, D> {
type Output = Sparse<F, N, D>;

fn add(self, other: &Sparse<F, N, D>) -> Self::Output {
let mut monomials = self.monomials.clone();
for (exponents, coeff) in &other.monomials {
monomials
.entry(*exponents)
.and_modify(|c| *c += *coeff)
.or_insert(*coeff);
}
// Remove monomials with zero coefficients
let monomials: HashMap<[usize; N], F> = monomials
.into_iter()
.filter(|(_, coeff)| !coeff.is_zero())
.collect();
// Handle the case where the result is zero because we want a unique
// representation
if monomials.is_empty() {
Self::zero()
} else {
Sparse::<F, N, D> { monomials }
}
&self + other
}
}

impl<const N: usize, const D: usize, F: PrimeField> Add<Sparse<F, N, D>> for &Sparse<F, N, D> {
type Output = Sparse<F, N, D>;

fn add(self, other: Sparse<F, N, D>) -> Self::Output {
let mut monomials = self.monomials.clone();
for (exponents, coeff) in other.monomials {
monomials
.entry(exponents)
.and_modify(|c| *c += coeff)
.or_insert(coeff);
}
// Remove monomials with zero coefficients
let monomials: HashMap<[usize; N], F> = monomials
.into_iter()
.filter(|(_, coeff)| !coeff.is_zero())
.collect();
// Handle the case where the result is zero because we want a unique
// representation
if monomials.is_empty() {
Sparse::<F, N, D>::zero()
} else {
Sparse::<F, N, D> { monomials }
}
self + &other
}
}

impl<const N: usize, const D: usize, F: PrimeField> Add<&Sparse<F, N, D>> for &Sparse<F, N, D> {
type Output = Sparse<F, N, D>;

Expand Down Expand Up @@ -190,12 +135,7 @@ impl<const N: usize, const D: usize, F: PrimeField> Neg for Sparse<F, N, D> {
type Output = Sparse<F, N, D>;

fn neg(self) -> Self::Output {
let monomials: HashMap<[usize; N], F> = self
.monomials
.into_iter()
.map(|(exponents, coeff)| (exponents, -coeff))
.collect();
Sparse::<F, N, D> { monomials }
-&self
}
}

Expand Down
27 changes: 4 additions & 23 deletions mvpoly/src/prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,41 +542,23 @@ impl<F: PrimeField, const N: usize, const D: usize> Add for Dense<F, N, D> {
type Output = Self;

fn add(self, other: Self) -> Self {
let coeffs = self
.coeff
.iter()
.zip(other.coeff.iter())
.map(|(a, b)| *a + *b)
.collect();
Self::from_coeffs(coeffs)
&self + &other
}
}

impl<F: PrimeField, const N: usize, const D: usize> Add<&Dense<F, N, D>> for Dense<F, N, D> {
type Output = Dense<F, N, D>;

fn add(self, other: &Dense<F, N, D>) -> Dense<F, N, D> {
let coeffs = self
.coeff
.iter()
.zip(other.coeff.iter())
.map(|(a, b)| *a + *b)
.collect();
Self::from_coeffs(coeffs)
&self + other
}
}

impl<F: PrimeField, const N: usize, const D: usize> Add<Dense<F, N, D>> for &Dense<F, N, D> {
type Output = Dense<F, N, D>;

fn add(self, other: Dense<F, N, D>) -> Dense<F, N, D> {
let coeffs = self
.coeff
.iter()
.zip(other.coeff.iter())
.map(|(a, b)| *a + *b)
.collect();
Dense::from_coeffs(coeffs)
self + &other
}
}

Expand Down Expand Up @@ -632,8 +614,7 @@ impl<F: PrimeField, const N: usize, const D: usize> Neg for Dense<F, N, D> {
type Output = Self;

fn neg(self) -> Self::Output {
let coeffs = self.coeff.iter().map(|c| -*c).collect();
Self::from_coeffs(coeffs)
-&self
}
}

Expand Down

0 comments on commit 1f8cf43

Please sign in to comment.