From 2dc5c8a891e2e1967fd748cdab2acc59a0c8b1d3 Mon Sep 17 00:00:00 2001 From: sarah el kazdadi Date: Fri, 28 Jun 2024 11:33:22 +0200 Subject: [PATCH] feat(perf): optimize some custom mod ops --- .../commons/math/decomposition/decomposer.rs | 4 +- .../core_crypto/commons/numeric/unsigned.rs | 41 ++++++------------- 2 files changed, 13 insertions(+), 32 deletions(-) diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index e98c8c90a7..82ff9bf277 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -384,9 +384,7 @@ where let modulus_as_scalar: Scalar = self.ciphertext_modulus.get_custom_modulus().cast_into(); match sign { ValueSign::Positive => abs_closest, - ValueSign::Negative => { - modulus_as_scalar.wrapping_sub_custom_mod(abs_closest, modulus_as_scalar) - } + ValueSign::Negative => abs_closest.wrapping_neg_custom_mod(modulus_as_scalar), } } diff --git a/tfhe/src/core_crypto/commons/numeric/unsigned.rs b/tfhe/src/core_crypto/commons/numeric/unsigned.rs index ed730257e6..e52e74a5d0 100644 --- a/tfhe/src/core_crypto/commons/numeric/unsigned.rs +++ b/tfhe/src/core_crypto/commons/numeric/unsigned.rs @@ -153,37 +153,17 @@ macro_rules! implement { } #[inline] fn wrapping_add_custom_mod(self, other: Self, custom_modulus: Self) -> Self { - if Self::BITS <= 64 { - let self_u128: u128 = self.cast_into(); - let other_u128: u128 = other.cast_into(); - let custom_modulus_u128: u128 = custom_modulus.cast_into(); - self_u128 - .wrapping_add(other_u128) - .wrapping_rem(custom_modulus_u128) - .cast_into() - } else { - if custom_modulus.is_power_of_two() { - return self.wrapping_add(other).wrapping_rem(custom_modulus); - } - todo!("wrapping_add_custom_mod is not yet implemented for non power of two moduli wider than u64") - } + self.wrapping_sub_custom_mod( + other.wrapping_neg_custom_mod(custom_modulus), + custom_modulus, + ) } #[inline] fn wrapping_sub_custom_mod(self, other: Self, custom_modulus: Self) -> Self { - if Self::BITS <= 64 { - let self_u128: u128 = self.cast_into(); - let other_u128: u128 = other.cast_into(); - let custom_modulus_u128: u128 = custom_modulus.cast_into(); - self_u128 - .wrapping_add(custom_modulus_u128) - .wrapping_sub(other_u128) - .wrapping_rem(custom_modulus_u128) - .cast_into() + if self >= other { + self - other } else { - if custom_modulus.is_power_of_two() { - return self.wrapping_sub(other).wrapping_rem(custom_modulus); - } - todo!("wrapping_sub_custom_mod is not yet implemented for non power of two moduli wider than u64") + self.wrapping_sub(other).wrapping_add(custom_modulus) } } #[inline] @@ -218,8 +198,11 @@ macro_rules! implement { } #[inline] fn wrapping_neg_custom_mod(self, custom_modulus: Self) -> Self { - // Custom modulus applied by wrapping_sub - Self::ZERO.wrapping_sub_custom_mod(self, custom_modulus) + if self == Self::ZERO { + self + } else { + custom_modulus - self + } } #[inline] fn wrapping_shl(self, rhs: u32) -> Self {