Skip to content

compiler-rt: alu: add saturated shift left for i8 #134 #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler-rt/src/alu.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ pub mod and;
pub mod or;
pub mod xor;
pub mod shl;
pub mod ushl_sat;
pub mod sshl_sat;
pub mod lshr;
pub mod fshl;
pub mod fshr;
Expand Down
16 changes: 2 additions & 14 deletions compiler-rt/src/alu/ashr/ashr_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,10 @@ mod tests {
// this instruction returns a poison value.
//
// As per `docs/ALU Design.md`, poison values are not supported.
pub const test_cases_panic: [TestCaseTwoArgs; 8] = [
TestCaseTwoArgs { lhs: 0b11111111, rhs: 8, expected: 0b11111111 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 9, expected: 0b11111111 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 90, expected: 0b11111111 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 123, expected: 0b11111111 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 8, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 9, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 90, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 123, expected: 0b00000000 },
];

#[test]
#[should_panic(expected: "Requested shift by more bits than input word size")]
fn test_i8_panic() {
for case in test_cases_panic.span() {
assert_eq!(__llvm_ashr_i8_i8(*case.lhs, *case.rhs), *case.expected);
}
let case = TestCaseTwoArgs { lhs: 0b11111111, rhs: 8, expected: 0b00000000 };
assert_eq!(__llvm_ashr_i8_i8(case.lhs, case.rhs), case.expected);
}
}
2 changes: 1 addition & 1 deletion compiler-rt/src/alu/ctlz.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use core::num::traits::{BitSize, Bounded};
//
// This is a generic implementation for every data type. Its specialized versions
// are defined and tested in the ctlz/ctlz_<type>.cairo files.
fn ctlz<
pub fn ctlz<
T,
// The trait bounds are chosen so that:
//
Expand Down
13 changes: 5 additions & 8 deletions compiler-rt/src/alu/shl.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::num::traits::OverflowingMul;
pub mod shl_i1;
pub mod shl_i8;
pub mod shl_i16;
Expand All @@ -12,10 +13,6 @@ use core::num::traits::{BitSize, Bounded};
//
// This is a generic implementation for every data type. Its specialized versions
// are defined and tested in the shl/shl_<type>.cairo files.
//
// Please note that this implementation is valid up to 64-bit values due to extra bits
// needed to accommodate overflows that happen during computation. A 128-bit implementation
// uses different approach.
pub fn shl<
T,
// The trait bounds are chosen so that:
Expand All @@ -42,10 +39,10 @@ pub fn shl<
let mut result = n;
// Perform the shift `shift`` number of times.
for _ in 0..shift {
result = result * 2;
// Make sure the result is limited only to the bit width of the concrete type.
result = result & Bounded::<T>::MAX.into();
let (r, _) = result.overflowing_mul(2);
result = r
};

result
// Make sure the result is limited only to the bit width of the concrete type.
result & Bounded::<T>::MAX.into()
}
37 changes: 2 additions & 35 deletions compiler-rt/src/alu/shl/shl_i128.cairo
Original file line number Diff line number Diff line change
@@ -1,40 +1,7 @@
use core::num::traits::{BitSize, WrappingAdd, Bounded};
use crate::utils::assert_fits_in_type;
use crate::alu::shl::shl;

pub fn __llvm_shl_i128_i128(n: u128, shift: u128) -> u128 {
// The generic shl::<T> function does not handle 128-bit values properly, hence a more complex
// approach is used here.

// Make sure the value passed in the u128 arguments can fit in the concrete type.
assert_fits_in_type::<u128>(n);
assert_fits_in_type::<u128>(shift);

// Cairo does not have << or >> operators so we must implement the shift manually.
let mut result = n;
// Perform the shift `shift`` number of times.
for _ in 0..shift {
// Initialize new_result to 0 for the current shift.
let mut new_result = 0;
// Initialize mask to 0b0000..1 (it will move to the left so we can check each bit).
let mut mask = 1;

// Iterate through each bit position of the integer.
for _ in 0..BitSize::<u128>::bits() {
if result & mask != 0 {
// If the current bit is set, set the corresponding bit in new_result,
// but shifted one position to the left.
//
// mask.wrapping_add(mask) is essentially mask * 2 or mask << 1
// with the benefit of wrapping back at 0 when we reach the MSB.
new_result = new_result | mask.wrapping_add(mask);
}
mask = mask.wrapping_add(mask);
};
result = new_result;
};

// Make sure the result is limited only to the bit width of the concrete type.
result & Bounded::<u128>::MAX.into()
shl::<u128>(n, shift)
}

#[cfg(test)]
Expand Down
21 changes: 3 additions & 18 deletions compiler-rt/src/alu/smul_with_overflow.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub mod smul_with_overflow_i8;

use crate::utils::{assert_fits_in_type};
use crate::utils::{assert_fits_in_type, extend_sign};
use crate::alu::shl::shl;
use core::num::traits::{BitSize, Bounded, OverflowingMul};

Expand Down Expand Up @@ -37,24 +37,9 @@ fn smul_with_overflow<
let value_mask = sign_bit_mask - 1;
let sign_ext_bit_mask = ~value_mask;

// Function performing sign extension. This is needed, because the polyfill API
// requires operands to be u128, despite the actual value can be e.g. i8.
// In such case the remaining MSBs of u128 are zero, even if the operand is negative
// and should be sign-extended.
let extend_sign = |value: u128,
sign_bit: bool| -> u128 {
if sign_bit {
sign_ext_bit_mask | value
} else {
value
}
};

// Extend signs of operands if necessary.
let lhs_sign_bit = (lhs & sign_bit_mask) != 0;
let lhs = extend_sign(lhs, lhs_sign_bit);
let rhs_sign_bit = (rhs & sign_bit_mask) != 0;
let rhs = extend_sign(rhs, rhs_sign_bit);
let lhs = extend_sign(lhs, sign_bit_mask);
let rhs = extend_sign(rhs, sign_bit_mask);

// Perform the multiplication and check for overflow.
let (result, overflow) = lhs.overflowing_mul(rhs);
Expand Down
103 changes: 103 additions & 0 deletions compiler-rt/src/alu/sshl_sat.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
pub mod sshl_sat_i8;

use crate::utils::{assert_fits_in_type, extend_sign};
use crate::alu::shl::shl;
use core::num::traits::{BitSize, Bounded};

// Perform the `sshl_sat` operation.
//
// This function performs signed saturating shift left. It behaves like a regular
// bitwise shift left with the additional behavior of:
// - clamping the output to the minimum possible value of a type, if the shifted
// value is less than the minimum possible value of a type.
// - clamping the output to the maximum possible value of a type, if the shifted
// value is larger than the maximum possible value of a type.
//
// The minimum and maximum values are determined with the assumption of the input
// value being a signed number. Therefore the MSB of the type is the sign bit.
// Bitwise, the minimum value is 0b10..00 and the maximum value is 0b01..11.
//
// The shift value cannot be equal higher than the bit width of the concrete type.
// E.g. for `n` being an 8-bit value, the maximum allowed `shift` is 7. In LLVM IR
// shifting by more bits than the bit width of the input value results in returning
// a poison value. As for now, Hieratika support poisons values by panicking.
//
// This is a generic implementation for every data type. Its specialized versions
// are defined and tested in the sshl_sat/sshl_sat_<type>.cairo files.
fn sshl_sat<
T,
// The trait bounds are chosen so that:
//
// - BitSize<T>: we can determine the length of the data type in bits,
// - Bounded<T>: we can determine min and max value of the type,
// - TryInto<u128, T>, Into<T, u128> - we can convert the type from/to u128,
// - Destruct<T>: the type can be dropped as the result of the downcasting check.
//
// Overall these trait bounds allow any unsigned integer to be used as the concrete type.
impl TBitSize: BitSize<T>,
impl TBounded: Bounded<T>,
impl TTryInto: TryInto<u128, T>,
impl TInto: Into<T, u128>,
impl TDestruct: Destruct<T>,
>(
n: u128, shift: u128,
) -> u128 {
// Make sure the value passed in the u128 arguments can fit in the concrete type.
assert_fits_in_type::<T>(n);

// As per the LLVM Language Reference Manual:
//
// If b is (statically or dynamically) equal to or larger than the number of bits in op1,
// this instruction returns a poison value.
//
// As per `docs/ALU Design.md`, poison values cause panics.
let bit_size = BitSize::<T>::bits().into();
if shift >= bit_size {
panic!("Requested shift by more bits than input word size")
}

if n == 0 {
return 0;
}

if shift == 0 {
return n;
}

let shifted = shl::<u128>(n, shift);

// Check if the shifted value is negative
let sign_bit_mask = shl::<u128>(1, bit_size - 1);
let is_shifted_negative = (shifted & sign_bit_mask) != 0;
let is_n_negative = (n & sign_bit_mask) != 0;

// Min/max values of iN
let max_value = sign_bit_mask - 1;
let min_value = sign_bit_mask;
#[cairofmt::skip]
let result = match (is_n_negative, is_shifted_negative) {
(false, false) => {
if shifted > max_value {
max_value
} else {
shifted
}
},
(false, true) => {
max_value
},
(true, false) => {
min_value
},
(true, true) => {
let shifted_sign_bit_mask = shl::<u128>(1, bit_size - 1 + shift);
if extend_sign(shifted, shifted_sign_bit_mask) > extend_sign(min_value, sign_bit_mask) {
shifted
} else {
min_value
}
},
};

result & Bounded::<T>::MAX.into()
}
126 changes: 126 additions & 0 deletions compiler-rt/src/alu/sshl_sat/sshl_sat_i8.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use crate::alu::sshl_sat::sshl_sat;

pub fn __llvm_sshl_sat_i8_i8(n: u128, shift: u128) -> u128 {
sshl_sat::<u8>(n, shift)
}

#[cfg(test)]
mod tests {
use super::__llvm_sshl_sat_i8_i8;
use crate::alu::test_case::TestCaseTwoArgs;
#[cairofmt::skip]
pub const test_cases: [TestCaseTwoArgs; 64] = [
// All possible shifts of -128 from 0 throughout the whole input value length.
// Since -128 is the lowest possible value, the output saturates
// at the minimum value.
TestCaseTwoArgs { lhs: 0b10000000, rhs: 0, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 1, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 2, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 3, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 4, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 5, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 6, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10000000, rhs: 7, expected: 0b10000000 },

// All possible shifts of -86 from 0 throughout the whole input value length.
// -86 << 1 == -86 * 2 == -172 < -128, so the output saturates
// at the minimum value.
TestCaseTwoArgs { lhs: 0b10101010, rhs: 0, expected: 0b10101010 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 1, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 2, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 3, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 4, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 5, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 6, expected: 0b10000000 },
TestCaseTwoArgs { lhs: 0b10101010, rhs: 7, expected: 0b10000000 },

// All possible shifts of -1 from 0 throughout the whole input value length.
// The value is shifted all the way to -128, so the saturation does not
// occur (or it does, but it is s equal to the actual result of the shift).
TestCaseTwoArgs { lhs: 0b11111111, rhs: 0, expected: 0b11111111 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 1, expected: 0b11111110 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 2, expected: 0b11111100 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 3, expected: 0b11111000 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 4, expected: 0b11110000 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 5, expected: 0b11100000 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 6, expected: 0b11000000 },
TestCaseTwoArgs { lhs: 0b11111111, rhs: 7, expected: 0b10000000 },

// All possible shifts of 0 from 0 throughout the whole input value length.
// No saturation because the result is always zero.
TestCaseTwoArgs { lhs: 0b00000000, rhs: 0, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 1, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 2, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 3, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 4, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 5, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 6, expected: 0b00000000 },
TestCaseTwoArgs { lhs: 0b00000000, rhs: 7, expected: 0b00000000 },

// All possible shifts of 1 from 0 throughout the whole input value length.
// 1 << 7 == 128 > 127, so the output saturates at the maximum value.
TestCaseTwoArgs { lhs: 0b00000001, rhs: 0, expected: 0b00000001 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 1, expected: 0b00000010 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 2, expected: 0b00000100 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 3, expected: 0b00001000 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 4, expected: 0b00010000 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 5, expected: 0b00100000 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 6, expected: 0b01000000 },
TestCaseTwoArgs { lhs: 0b00000001, rhs: 7, expected: 0b01111111 },

// All possible shifts of 15 from 0 throughout the whole input value length.
// No saturation up to 15 << 3 == 120 < 127. Saturation at the maximum
// value occurs at 15 << 4 == 240 > 127.
TestCaseTwoArgs { lhs: 0b00001111, rhs: 0, expected: 0b00001111 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 1, expected: 0b00011110 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 2, expected: 0b00111100 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 3, expected: 0b01111000 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 4, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 5, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 6, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b00001111, rhs: 7, expected: 0b01111111 },

// All possible shifts of 85 from 0 throughout the whole input value length.
// 85 << 1 == 85 * 2 == 170 > 127, so the output saturates
// at the minimum value.
TestCaseTwoArgs { lhs: 0b01010101, rhs: 0, expected: 0b01010101 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 1, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 2, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 3, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 4, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 5, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 6, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01010101, rhs: 7, expected: 0b01111111 },

// All possible shifts of 127 from 0 throughout the whole input value length.
// Since 127 is the highest possible value, the output saturates
// at the maximum value.
TestCaseTwoArgs { lhs: 0b01111111, rhs: 0, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 1, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 2, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 3, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 4, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 5, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 6, expected: 0b01111111 },
TestCaseTwoArgs { lhs: 0b01111111, rhs: 7, expected: 0b01111111 },
];
#[test]
fn test_i8() {
for case in test_cases.span() {
assert_eq!(__llvm_sshl_sat_i8_i8(*case.lhs, *case.rhs), *case.expected);
}
}

// As per the LLVM Language Reference Manual:
//
// If b is (statically or dynamically) equal to or larger than the number of bits in op1,
// this instruction returns a poison value.
//
// As per `docs/ALU Design.md`, poison values cause panics.
#[test]
#[should_panic(expected: "Requested shift by more bits than input word size")]
fn test_i8_panic() {
let case = TestCaseTwoArgs { lhs: 0b11111111, rhs: 8, expected: 0b00000000 };
assert_eq!(__llvm_sshl_sat_i8_i8(case.lhs, case.rhs), case.expected);
}
}
Loading
Loading