Skip to content

Commit

Permalink
feat(integer): add reverse_bits
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Jul 25, 2024
1 parent 19dc0f0 commit 434fbad
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tfhe/src/high_level_api/integers/signed/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,42 @@ where
{
self.ciphertext.on_cpu().decrypt_trivial()
}

/// Reverse the bit of the signed integer
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheInt8};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let msg = 0b0110100_i8;
///
/// let a = FheInt8::encrypt(msg, &client_key);
///
/// let result: FheInt8 = a.reverse_bits();
///
/// let decrypted: i8 = result.decrypt(&client_key);
/// assert_eq!(decrypted, msg.reverse_bits());
/// ```
pub fn reverse_bits(&self) -> Self {
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let sk = &cpu_key.pbs_key();

let ct = self.ciphertext.on_cpu();

Self::new(RadixCiphertext::Cpu(sk.reverse_bits_parallelized(&*ct)))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support reverse yet");
}
})
}
}

impl<FromId, IntoId> CastFrom<FheInt<FromId>> for FheInt<IntoId>
Expand Down
36 changes: 36 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,42 @@ where
}
})
}

/// Reverse the bit of the unsigned integer
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let msg = 0b10110100_u8;
///
/// let a = FheUint8::encrypt(msg, &client_key);
///
/// let result: FheUint8 = a.reverse_bits();
///
/// let decrypted: u8 = result.decrypt(&client_key);
/// assert_eq!(decrypted, msg.reverse_bits());
/// ```
pub fn reverse_bits(&self) -> Self {
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let sk = &cpu_key.pbs_key();

let ct = self.ciphertext.on_cpu();

Self::new(RadixCiphertext::Cpu(sk.reverse_bits_parallelized(&*ct)))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support reverse yet");
}
})
}
}

impl<Id> TryFrom<crate::integer::RadixCiphertext> for FheUint<Id>
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/server_key/radix_parallel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub(crate) mod sub;
mod sum;

mod ilog2;
mod reverse_bits;
#[cfg(test)]
pub(crate) mod tests_cases_unsigned;
#[cfg(test)]
Expand Down
126 changes: 126 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/reverse_bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use super::ServerKey;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

impl ServerKey {
/// Reverse the bits of the integer
///
/// # Example
///
///```rust
/// use tfhe::integer::{gen_keys_radix, IntegerCiphertext};
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let num_blocks = 4;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
///
/// let msg = 0b10110100_u8;
///
/// let ct = cks.encrypt(msg);
///
/// // Compute homomorphically an addition:
/// let mut ct_res = sks.reverse_bits_parallelized(&ct);
///
/// // Decrypt:
/// let res: u8 = cks.decrypt(&ct_res);
/// assert_eq!(msg.reverse_bits(), res);
/// ```
pub fn reverse_bits_parallelized<T>(&self, ct: &T) -> T
where
T: IntegerRadixCiphertext,
{
let message_modulus = self.message_modulus().0 as u64;

let mut clean_ct;

let ct = if ct.block_carries_are_empty() {
ct
} else {
clean_ct = ct.clone();
self.full_propagate_parallelized(&mut clean_ct);
&clean_ct
};

let lut = self.key.generate_lookup_table(|x| {
(x % message_modulus).reverse_bits() >> (64 - message_modulus.ilog2())
});

let blocks = ct
.blocks()
.par_iter()
.rev()
.map(|block| self.key.apply_lookup_table(block, &lut))
.collect();

T::from_blocks(blocks)
}
}

#[cfg(test)]
mod tests {
use super::ServerKey;
use crate::integer::ciphertext::RadixCiphertext;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::tests::create_parametrized_test;
use crate::integer::{IntegerKeyKind, RadixClientKey};
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::*;
use crate::shortint::PBSParameters;
use rand::prelude::*;
use std::sync::Arc;

pub(crate) fn reverse_bits_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>,
{
let param = param.into();
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let sks = Arc::new(sks);

let nb_blocks = 4;

let cks = RadixClientKey::from((cks, nb_blocks));

executor.setup(&cks, sks);

let log_modulus = nb_blocks * param.message_modulus().0.ilog2() as usize;
let modulus = 1 << log_modulus;

let nb_tests = 10;

let mut rng = rand::thread_rng();

for _ in 0..nb_tests {
let clear = rng.gen::<u64>() % modulus;

let ct = cks.encrypt(clear);

let result = executor.execute(&ct);
let decrypted_result: u64 = cks.decrypt(&result);

let expected_result = clear.reverse_bits() >> (64 - log_modulus);

assert_eq!(
decrypted_result, expected_result,
"Invalid reverse_bits result, gave clear = {clear}, \
expected {expected_result}, got {decrypted_result}"
);
}
}

fn integer_reverse_bits<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
reverse_bits_test(param, executor);
}

create_parametrized_test!(integer_reverse_bits);
}

0 comments on commit 434fbad

Please sign in to comment.