Skip to content

Commit

Permalink
Added WideSquare trait. (#6194)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Aug 20, 2024
1 parent c803c43 commit 86c6cae
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 3 deletions.
2 changes: 1 addition & 1 deletion corelib/src/integer.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ pub fn u256_wide_mul(a: u256, b: u256) -> u512 nopanic {

/// Helper function for implementation of `u256_wide_mul`.
/// Used for adding two u128s and receiving a BoundedInt for the carry result.
fn u128_add_with_bounded_int_carry(
pub(crate) fn u128_add_with_bounded_int_carry(
a: u128, b: u128
) -> (u128, core::internal::bounded_int::BoundedInt<0, 1>) nopanic {
match u128_overflowing_add(a, b) {
Expand Down
1 change: 1 addition & 0 deletions corelib/src/num/traits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ pub use ops::wrapping::{WrappingAdd, WrappingSub, WrappingMul};
pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul};
pub use ops::saturating::{SaturatingAdd, SaturatingSub, SaturatingMul};
pub use ops::widemul::WideMul;
pub use ops::widesquare::WideSquare;
pub use ops::sqrt::Sqrt;
1 change: 1 addition & 0 deletions corelib/src/num/traits/ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod checked;
pub mod saturating;
pub(crate) mod sqrt;
pub(crate) mod widemul;
pub(crate) mod widesquare;
61 changes: 61 additions & 0 deletions corelib/src/num/traits/ops/widesquare.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use core::num::traits::WideMul;

/// A trait for a type that can be squared to produce a wider type.
pub trait WideSquare<T> {
/// The type of the result of the square.
type Target;
/// Calculates the square, producing a wider type.
fn wide_square(self: T) -> Self::Target;
}

mod wide_mul_based {
pub impl TWideSquare<T, impl TWideMul: super::WideMul<T, T>, +Copy<T>> of super::WideSquare<T> {
type Target = TWideMul::Target;
fn wide_square(self: T) -> Self::Target {
TWideMul::wide_mul(self, self)
}
}
}

impl WideSquareI8 = wide_mul_based::TWideSquare<i8>;
impl WideSquareI16 = wide_mul_based::TWideSquare<i16>;
impl WideSquareI32 = wide_mul_based::TWideSquare<i32>;
impl WideSquareI64 = wide_mul_based::TWideSquare<i64>;
impl WideSquareU8 = wide_mul_based::TWideSquare<u8>;
impl WideSquareU16 = wide_mul_based::TWideSquare<u16>;
impl WideSquareU32 = wide_mul_based::TWideSquare<u32>;
impl WideSquareU64 = wide_mul_based::TWideSquare<u64>;
impl WideSquareU128 = wide_mul_based::TWideSquare<u128>;
impl WideSquareU256 of WideSquare<u256> {
type Target = core::integer::u512;
fn wide_square(self: u256) -> Self::Target {
inner::u256_wide_square(self)
}
}

mod inner {
use core::integer::{u512, u128_add_with_bounded_int_carry, upcast};
use core::internal::bounded_int;
use core::num::traits::{WideSquare, WideMul, WrappingAdd};

pub fn u256_wide_square(value: u256) -> u512 {
let u256 { high: limb1, low: limb0 } = value.low.wide_square();
let u256 { high: limb2, low: limb1_part } = value.low.wide_mul(value.high);
let (limb1, limb1_overflow0) = u128_add_with_bounded_int_carry(limb1, limb1_part);
let (limb1, limb1_overflow1) = u128_add_with_bounded_int_carry(limb1, limb1_part);
let (limb2, limb2_overflow0) = u128_add_with_bounded_int_carry(limb2, limb2);
let u256 { high: limb3, low: limb2_part } = value.high.wide_square();
let (limb2, limb2_overflow1) = u128_add_with_bounded_int_carry(limb2, limb2_part);
// Packing together the overflow bits, making a cheaper addition into limb2.
let limb1_overflow = bounded_int::add(limb1_overflow0, limb1_overflow1);
let (limb2, limb2_overflow2) = u128_add_with_bounded_int_carry(
limb2, upcast(limb1_overflow)
);
// Packing together the overflow bits, making a cheaper addition into limb3.
let limb2_overflow = bounded_int::add(limb2_overflow0, limb2_overflow1);
let limb2_overflow = bounded_int::add(limb2_overflow, limb2_overflow2);
// No overflow since no limb4.
let limb3 = limb3.wrapping_add(upcast(limb2_overflow));
u512 { limb0, limb1, limb2, limb3 }
}
}
27 changes: 25 additions & 2 deletions corelib/src/test/integer_test.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[feature("deprecated-bounded-int-trait")]
use core::{integer, integer::{u512_safe_div_rem_by_u256, u512}};
use core::test::test_utils::{assert_eq, assert_ne, assert_le, assert_lt, assert_gt, assert_ge};
use core::num::traits::{Bounded, Sqrt, WideMul, WrappingSub};
use core::num::traits::{Bounded, Sqrt, WideMul, WideSquare, WrappingSub};

#[test]
fn test_u8_operators() {
Expand Down Expand Up @@ -706,6 +706,29 @@ fn test_u256_wide_mul() {
);
}

#[test]
fn test_u256_wide_square() {
assert!(0_u256.wide_square() == u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 });
assert!(
0x1001001001001001001001001001001001001001001001001001_u256
.wide_square() == u512 {
limb0: 0x0b00a009008007006005004003002001,
limb1: 0xe00f01001101201101000f00e00d00c0,
limb2: 0x00400500600700800900a00b00c00d00,
limb3: 0x1002003
}
);
assert!(
0x1000100010001000100010001000100010001000100010001000100010001_u256
.wide_square() == u512 {
limb0: 0x00080007000600050004000300020001,
limb1: 0x0010000f000e000d000c000b000a0009,
limb2: 0x00080009000a000b000c000d000e000f,
limb3: 0x1000200030004000500060007
}
);
}

#[test]
fn test_u512_safe_div_rem_by_u256() {
let zero = u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 };
Expand Down Expand Up @@ -846,7 +869,7 @@ fn test_u256_sqrt() {
assert!(1_u256.sqrt() == 1);
assert!(0_u256.sqrt() == 0);
assert!(Bounded::<u256>::MAX.sqrt() == Bounded::<u128>::MAX);
assert!(Bounded::<u128>::MAX.wide_mul(Bounded::<u128>::MAX).sqrt() == Bounded::<u128>::MAX);
assert!(Bounded::<u128>::MAX.wide_square().sqrt() == Bounded::<u128>::MAX);
}

#[test]
Expand Down

0 comments on commit 86c6cae

Please sign in to comment.