From 7305c1849542e7c04a0d240046c843482ba1cc99 Mon Sep 17 00:00:00 2001
From: Agnes Leroy
Date: Mon, 24 Feb 2025 17:22:06 +0100
Subject: [PATCH] fix(gpu): fix scalar comparisons with 1 block

---
 .../cuda/include/integer/integer_utilities.h  |   1 +
 .../cuda/src/integer/scalar_comparison.cu     |  35 ++-
 .../cuda/src/integer/scalar_comparison.cuh    | 283 ++++++++++++------
 tfhe/src/high_level_api/booleans/base.rs      |  13 +-
 .../high_level_api/integers/signed/encrypt.rs |  24 +-
 .../tests_unsigned/test_scalar_comparison.rs  | 106 +++++++
 .../tests_unsigned/test_scalar_comparison.rs  |  90 ++++++
 7 files changed, 448 insertions(+), 104 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
index ad842520f9..b9ed09f5f9 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
@@ -4728,4 +4728,5 @@ void update_degrees_after_scalar_bitxor(uint64_t *output_degrees,
                                         uint64_t *clear_degrees,
                                         uint64_t *input_degrees,
                                         uint32_t num_clear_blocks);
+std::pair<bool, bool> get_invert_flags(COMPARISON_TYPE compare);
 #endif // CUDA_INTEGER_UTILITIES_H
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu
index 3417754dca..43076b6ed0 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu
@@ -1,5 +1,36 @@
 #include "integer/scalar_comparison.cuh"
+#include <utility> // for std::pair
+
+std::pair<bool, bool> get_invert_flags(COMPARISON_TYPE compare) {
+  bool invert_operands;
+  bool invert_subtraction_result;
+
+  switch (compare) {
+  case COMPARISON_TYPE::LT:
+    invert_operands = false;
+    invert_subtraction_result = false;
+    break;
+  case COMPARISON_TYPE::LE:
+    invert_operands = true;
+    invert_subtraction_result = true;
+    break;
+  case COMPARISON_TYPE::GT:
+    invert_operands = true;
+    invert_subtraction_result = false;
+    break;
+  case COMPARISON_TYPE::GE:
+    invert_operands = false;
+    invert_subtraction_result = true;
+    break;
+  default:
+    PANIC("Cuda error: invalid comparison type")
+  }
+
+  return {invert_operands, invert_subtraction_result};
+}
+
 void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
     void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
     void *lwe_array_out, void const *lwe_array_in, void const *scalar_blocks,
@@ -22,9 +53,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
   case GE:
   case LT:
   case LE:
-    if (lwe_ciphertext_count % 2 != 0)
+    if (lwe_ciphertext_count % 2 != 0 && lwe_ciphertext_count != 1)
       PANIC("Cuda error (scalar comparisons): the number of radix blocks has "
-            "to be even.")
+            "to be even or equal to 1.")
     host_integer_radix_scalar_difference_check_kb(
         (cudaStream_t *)(streams), gpu_indexes, gpu_count,
         static_cast<uint64_t *>(lwe_array_out),
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
index 9c0c6a95ca..b5dad2b20f 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
@@ -2,6 +2,27 @@
 #define CUDA_INTEGER_SCALAR_COMPARISON_OPS_CUH
 
 #include "integer/comparison.cuh"
+template <typename Torus>
+Torus is_x_less_than_y_given_input_borrow(Torus last_x_block,
+                                          Torus last_y_block, Torus borrow,
+                                          uint32_t message_modulus) {
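+  // Comparison by subtraction on the most significant (sign-carrying) block:
+  // x < y exactly when the sign bit of x - (y + borrow) differs from the
+  // subtraction overflow, and the overflow is the XOR of the borrow into the
+  // last bit and the borrow out of it.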
+  Torus last_bit_pos = log2_int(message_modulus) - 1;
+  Torus mask = (1 << last_bit_pos) - 1;
+  Torus x_without_last_bit = last_x_block & mask;
+  Torus y_without_last_bit = last_y_block & mask;
+
+  bool input_borrow_to_last_bit =
+      x_without_last_bit < (y_without_last_bit + borrow);
+
+  Torus result = last_x_block - (last_y_block + borrow);
+
+  Torus output_sign_bit = (result >> last_bit_pos) & 1;
+  bool output_borrow = last_x_block < (last_y_block + borrow);
+
+  Torus overflow_flag = (Torus)(input_borrow_to_last_bit ^ output_borrow);
+
+  return output_sign_bit ^ overflow_flag;
+}
 
 template <typename Torus>
 __host__ void scalar_compare_radix_blocks_kb(
@@ -205,42 +226,79 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
         lwe_array_msb_out, bsks, ksks, 1, lut, lut->params.message_modulus);
   } else {
-    // We only have to do the regular comparison
-    // And not the part where we compare most significant blocks with zeros
-    // total_num_radix_blocks == total_num_scalar_blocks
-    uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
-    uint32_t num_scalar_blocks = total_num_scalar_blocks;
-
-    Torus *lhs = diff_buffer->tmp_packed;
-    Torus *rhs =
-        diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;
-
-    pack_blocks(streams[0], gpu_indexes[0], lhs, lwe_array_in,
-                big_lwe_dimension, num_lsb_radix_blocks, message_modulus);
-    pack_blocks(streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
-                num_scalar_blocks, message_modulus);
-
-    // From this point we have half number of blocks
-    num_lsb_radix_blocks /= 2;
-    num_scalar_blocks /= 2;
-
-    // comparisons will be assigned
-    // - 0 if lhs < rhs
-    // - 1 if lhs == rhs
-    // - 2 if lhs > rhs
-    auto comparisons = mem_ptr->tmp_lwe_array_out;
-    scalar_compare_radix_blocks_kb(streams, gpu_indexes, gpu_count,
-                                   comparisons, lhs, rhs, mem_ptr, bsks,
-                                   ksks, num_lsb_radix_blocks);
-
-    // Reduces a vec containing radix blocks that encrypts a sign
-    // (inferior, equal, superior) to one single radix block containing the
-    // final sign
-    tree_sign_reduction(streams, gpu_indexes, gpu_count, lwe_array_out,
-                        comparisons, mem_ptr->diff_buffer->tree_buffer,
-                        sign_handler_f, bsks, ksks, num_lsb_radix_blocks);
+    if (total_num_radix_blocks == 1) {
+      std::pair<bool, bool> invert_flags = get_invert_flags(mem_ptr->op);
+      Torus scalar = 0;
+      cuda_memcpy_async_to_cpu(&scalar, scalar_blocks, sizeof(Torus),
+                               streams[0], gpu_indexes[0]);
+      cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+      // The scalar fits in a single block, so it can be baked into a
+      // univariate LUT: one PBS then evaluates the whole comparison.
+      auto one_block_lut_f = [invert_flags, scalar](Torus x) -> Torus {
+        Torus x_0;
+        Torus x_1;
+        if (invert_flags.first) {
+          x_0 = scalar;
+          x_1 = x;
+        } else {
+          x_0 = x;
+          x_1 = scalar;
+        }
+        auto overflowed = x_0 < x_1;
+        return (Torus)(invert_flags.second ^ overflowed);
+      };
+      int_radix_lut<Torus> *one_block_lut = new int_radix_lut<Torus>(
+          streams, gpu_indexes, gpu_count, params, 1, 1, true);
+
+      generate_device_accumulator<Torus>(
+          streams[0], gpu_indexes[0], one_block_lut->get_lut(0, 0),
+          one_block_lut->get_degree(0), one_block_lut->get_max_degree(0),
+          params.glwe_dimension, params.polynomial_size,
+          params.message_modulus, params.carry_modulus, one_block_lut_f);
+
+      one_block_lut->broadcast_lut(streams, gpu_indexes, 0);
+
+      legacy_integer_radix_apply_univariate_lookup_table_kb(
+          streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, bsks,
+          ksks, 1, one_block_lut);
+      one_block_lut->release(streams, gpu_indexes, gpu_count);
+      delete one_block_lut;
+    } else {
+      // We only have to do the regular comparison
+      // And not the part where we compare most significant blocks with zeros
+      // total_num_radix_blocks == total_num_scalar_blocks
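+      // Below, blocks are packed two by two (each pair's high block is
+      // shifted into the carry space of the low block), halving the number
+      // of blocks to compare; this packing is why an even count is required.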
+      uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
+      uint32_t num_scalar_blocks = total_num_scalar_blocks;
+
+      Torus *lhs = diff_buffer->tmp_packed;
+      Torus *rhs =
+          diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;
+
+      pack_blocks(streams[0], gpu_indexes[0], lhs, lwe_array_in,
+                  big_lwe_dimension, num_lsb_radix_blocks, message_modulus);
+      pack_blocks(streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
+                  num_scalar_blocks, message_modulus);
+
+      // From this point we have half number of blocks
+      num_lsb_radix_blocks /= 2;
+      num_scalar_blocks /= 2;
+
+      // comparisons will be assigned
+      // - 0 if lhs < rhs
+      // - 1 if lhs == rhs
+      // - 2 if lhs > rhs
+      auto comparisons = mem_ptr->tmp_lwe_array_out;
+      scalar_compare_radix_blocks_kb(streams, gpu_indexes, gpu_count,
+                                     comparisons, lhs, rhs, mem_ptr, bsks,
+                                     ksks, num_lsb_radix_blocks);
+
+      // Reduces a vec containing radix blocks that encrypts a sign
+      // (inferior, equal, superior) to one single radix block containing the
+      // final sign
+      tree_sign_reduction(streams, gpu_indexes, gpu_count, lwe_array_out,
+                          comparisons, mem_ptr->diff_buffer->tree_buffer,
+                          sign_handler_f, bsks, ksks, num_lsb_radix_blocks);
+    }
   }
 }
@@ -448,65 +506,104 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
         2);
 
   } else {
-    // We only have to do the regular comparison
-    // And not the part where we compare most significant blocks with zeros
-    // total_num_radix_blocks == total_num_scalar_blocks
-    uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
-
-    for (uint j = 0; j < gpu_count; j++) {
-      cuda_synchronize_stream(streams[j], gpu_indexes[j]);
-    }
-    auto lsb_streams = mem_ptr->lsb_streams;
-    auto msb_streams = mem_ptr->msb_streams;
-
-    auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
-    auto lwe_array_sign_out =
-        lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
-    Torus *lhs = diff_buffer->tmp_packed;
-    Torus *rhs =
-        diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;
-
-    pack_blocks(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
-                big_lwe_dimension, num_lsb_radix_blocks - 1, message_modulus);
-    pack_blocks(lsb_streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
-                num_lsb_radix_blocks - 1, message_modulus);
-
-    // From this point we have half number of blocks
-    num_lsb_radix_blocks /= 2;
-
-    // comparisons will be assigned
-    // - 0 if lhs < rhs
-    // - 1 if lhs == rhs
-    // - 2 if lhs > rhs
-    scalar_compare_radix_blocks_kb(lsb_streams, gpu_indexes, gpu_count,
-                                   lwe_array_ct_out, lhs, rhs, mem_ptr,
-                                   bsks, ksks, num_lsb_radix_blocks);
-    Torus const *encrypted_sign_block =
-        lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
-    Torus const *scalar_sign_block =
-        scalar_blocks + (total_num_scalar_blocks - 1);
-
-    auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
-    create_trivial_radix(
-        msb_streams[0], gpu_indexes[0], trivial_sign_block, scalar_sign_block,
-        big_lwe_dimension, 1, 1, message_modulus, carry_modulus);
+    if (total_num_radix_blocks == 1) {
+      std::pair<bool, bool> invert_flags = get_invert_flags(mem_ptr->op);
+      Torus scalar = 0;
+      cuda_memcpy_async_to_cpu(&scalar, scalar_blocks, sizeof(Torus),
+                               streams[0], gpu_indexes[0]);
+      cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+      auto one_block_lut_f = [invert_flags, scalar,
+                              message_modulus](Torus x) -> Torus {
+        Torus x_0;
+        Torus x_1;
+        if (invert_flags.first) {
+          x_0 = scalar;
+          x_1 = x;
+        } else {
+          x_0 = x;
+          x_1 = scalar;
+        }
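+        // Compare the single (sign-carrying) blocks with an input borrow of
+        // 0; invert_flags.second flips the strict comparison into LE/GE.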
+        return (Torus)(invert_flags.second) ^
+               is_x_less_than_y_given_input_borrow<Torus>(x_0, x_1, 0,
+                                                          message_modulus);
+      };
+      int_radix_lut<Torus> *one_block_lut = new int_radix_lut<Torus>(
+          streams, gpu_indexes, gpu_count, params, 1, 1, true);
+
+      generate_device_accumulator<Torus>(
+          streams[0], gpu_indexes[0], one_block_lut->get_lut(0, 0),
+          one_block_lut->get_degree(0), one_block_lut->get_max_degree(0),
+          params.glwe_dimension, params.polynomial_size,
+          params.message_modulus, params.carry_modulus, one_block_lut_f);
+
+      one_block_lut->broadcast_lut(streams, gpu_indexes, 0);
+
+      legacy_integer_radix_apply_univariate_lookup_table_kb(
+          streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, bsks,
+          ksks, 1, one_block_lut);
+      one_block_lut->release(streams, gpu_indexes, gpu_count);
+      delete one_block_lut;
+    } else {
+      // We only have to do the regular comparison
+      // And not the part where we compare most significant blocks with zeros
+      // total_num_radix_blocks == total_num_scalar_blocks
+      uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
+
+      for (uint j = 0; j < gpu_count; j++) {
+        cuda_synchronize_stream(streams[j], gpu_indexes[j]);
+      }
+      auto lsb_streams = mem_ptr->lsb_streams;
+      auto msb_streams = mem_ptr->msb_streams;
+
+      auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
+      auto lwe_array_sign_out =
+          lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
+      Torus *lhs = diff_buffer->tmp_packed;
+      Torus *rhs =
+          diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;
+
+      pack_blocks(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
+                  big_lwe_dimension, num_lsb_radix_blocks - 1,
+                  message_modulus);
+      pack_blocks(lsb_streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
+                  num_lsb_radix_blocks - 1, message_modulus);
+
+      // From this point we have half number of blocks
+      num_lsb_radix_blocks /= 2;
+
+      // comparisons will be assigned
+      // - 0 if lhs < rhs
+      // - 1 if lhs == rhs
+      // - 2 if lhs > rhs
+      scalar_compare_radix_blocks_kb(lsb_streams, gpu_indexes, gpu_count,
+                                     lwe_array_ct_out, lhs, rhs, mem_ptr,
+                                     bsks, ksks, num_lsb_radix_blocks);
+      Torus const *encrypted_sign_block =
+          lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
+      Torus const *scalar_sign_block =
+          scalar_blocks + (total_num_scalar_blocks - 1);
+
+      auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
+      create_trivial_radix(
+          msb_streams[0], gpu_indexes[0], trivial_sign_block,
+          scalar_sign_block, big_lwe_dimension, 1, 1, message_modulus,
+          carry_modulus);
+
+      legacy_integer_radix_apply_bivariate_lookup_table_kb(
+          msb_streams, gpu_indexes, gpu_count, lwe_array_sign_out,
+          encrypted_sign_block, trivial_sign_block, bsks, ksks, 1,
+          mem_ptr->signed_lut, mem_ptr->signed_lut->params.message_modulus);
+      for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
+        cuda_synchronize_stream(lsb_streams[j], gpu_indexes[j]);
+        cuda_synchronize_stream(msb_streams[j], gpu_indexes[j]);
+      }
 
-    legacy_integer_radix_apply_bivariate_lookup_table_kb(
-        msb_streams, gpu_indexes, gpu_count, lwe_array_sign_out,
-        encrypted_sign_block, trivial_sign_block, bsks, ksks, 1,
-        mem_ptr->signed_lut, mem_ptr->signed_lut->params.message_modulus);
-    for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
-      cuda_synchronize_stream(lsb_streams[j], gpu_indexes[j]);
-      cuda_synchronize_stream(msb_streams[j], gpu_indexes[j]);
-    }
+      // Reduces a vec containing radix blocks that encrypts a sign
+      // (inferior, equal, superior) to one single radix block containing the
+      // final sign
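+      // The "+ 1" accounts for the extra block holding the sign-block
+      // comparison result computed on msb_streams above.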
+      reduce_signs(streams, gpu_indexes, gpu_count, lwe_array_out,
+                   lwe_array_ct_out, mem_ptr, sign_handler_f, bsks, ksks,
+                   num_lsb_radix_blocks + 1);
+    }
-
-    // Reduces a vec containing radix blocks that encrypts a sign
-    // (inferior, equal, superior) to one single radix block containing the
-    // final sign
-    reduce_signs(streams, gpu_indexes, gpu_count, lwe_array_out,
-                 lwe_array_ct_out, mem_ptr, sign_handler_f, bsks, ksks,
-                 num_lsb_radix_blocks + 1);
   }
 }
diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs
index 207244afda..83fec2bcb9 100644
--- a/tfhe/src/high_level_api/booleans/base.rs
+++ b/tfhe/src/high_level_api/booleans/base.rs
@@ -245,9 +245,16 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
                 FheInt::new(new_ct, key.tag.clone())
             }
             #[cfg(feature = "gpu")]
-            InternalServerKey::Cuda(_) => {
-                panic!("Cuda devices do not support signed integers")
-            }
+            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
+                let inner = cuda_key.key.key.if_then_else(
+                    &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
+                    &*ct_then.ciphertext.on_gpu(streams),
+                    &*ct_else.ciphertext.on_gpu(streams),
+                    streams,
+                );
+
+                FheInt::new(inner, cuda_key.tag.clone())
+            }),
         })
     }
 }
diff --git a/tfhe/src/high_level_api/integers/signed/encrypt.rs b/tfhe/src/high_level_api/integers/signed/encrypt.rs
index 9a75e5e1aa..2bb1546e40 100644
--- a/tfhe/src/high_level_api/integers/signed/encrypt.rs
+++ b/tfhe/src/high_level_api/integers/signed/encrypt.rs
@@ -1,8 +1,12 @@
 use crate::core_crypto::prelude::SignedNumeric;
 use crate::high_level_api::global_state;
+#[cfg(feature = "gpu")]
+use crate::high_level_api::global_state::with_thread_local_cuda_streams;
 use crate::high_level_api::integers::FheIntId;
+use crate::high_level_api::keys::InternalServerKey;
 use crate::integer::block_decomposition::DecomposableInto;
 use crate::integer::client_key::RecomposableSignedInteger;
+#[cfg(feature = "gpu")]
+use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
 use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};
 use crate::{ClientKey, CompressedPublicKey, FheInt, PublicKey};
@@ -101,14 +105,22 @@ where
     /// Trivial encryptions become real encrypted data once used in an operation
     /// that involves a real ciphertext
     fn try_encrypt_trivial(value: T) -> Result<Self, Self::Error> {
-        global_state::with_cpu_internal_keys(|sks| {
-            let ciphertext = sks
-                .pbs_key()
-                .create_trivial_radix::<T, crate::integer::SignedRadixCiphertext>(
+        global_state::with_internal_keys(|key| match key {
+            InternalServerKey::Cpu(key) => {
+                let ciphertext: crate::integer::SignedRadixCiphertext = key
+                    .pbs_key()
+                    .create_trivial_radix(value, Id::num_blocks(key.message_modulus()));
+                Ok(Self::new(ciphertext, key.tag.clone()))
+            }
+            #[cfg(feature = "gpu")]
+            InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
+                let inner: CudaSignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
                     value,
-                    Id::num_blocks(sks.message_modulus()),
+                    Id::num_blocks(cuda_key.key.key.message_modulus),
+                    streams,
                 );
-            Ok(Self::new(ciphertext, sks.tag.clone()))
+                Ok(Self::new(inner, cuda_key.tag.clone()))
+            }),
         })
     }
 }
diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs
index 2d1ca64bd3..158b0e5dff 100644
--- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_scalar_comparison.rs
@@ -226,6 +226,106 @@ where
     }
 }
 
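+// Edge case: scalar comparisons on a ciphertext made of a single radix block,
+// which used to trip the "number of radix blocks has to be even" check on
+// the GPU backend.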
+fn integer_unchecked_scalar_comparisons_edge_one_block<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let p = param.into();
+    let num_block = 1;
+
+    let stream = CudaStreams::new_multi_gpu();
+
+    let (cks, sks) = gen_keys_gpu(p, &stream);
+    let message_modulus = cks.parameters().message_modulus().0;
+
+    let mut rng = rand::thread_rng();
+
+    for _ in 0..4 {
+        let clear_a = rng.gen_range(0..message_modulus);
+        let clear_b = rng.gen_range(0..message_modulus);
+
+        let a = cks.encrypt_radix(clear_a, num_block);
+        // Copy to the GPU
+        let d_a = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&a, &stream);
+
+        // >=
+        {
+            let d_result = sks.unchecked_scalar_ge(&d_a, clear_b, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a >= clear_b);
+        }
+
+        // >
+        {
+            let d_result = sks.unchecked_scalar_gt(&d_a, clear_b, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a > clear_b);
+        }
+
+        // <=
+        {
+            let d_result = sks.unchecked_scalar_le(&d_a, clear_b, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a <= clear_b);
+        }
+
+        // <
+        {
+            let d_result = sks.unchecked_scalar_lt(&d_a, clear_b, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a < clear_b);
+        }
+
+        // ==
+        {
+            let d_result = sks.unchecked_scalar_eq(&d_a, clear_b, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a == clear_b);
+        }
+
+        // !=
+        {
+            let d_result = sks.unchecked_scalar_ne(&d_a, clear_b, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a != clear_b);
+        }
+
+        // Here the goal is to test the branching
+        // made in the scalar sign function
+        //
+        // We are forcing one of the two branches to work on empty slices
+        {
+            let d_result = sks.unchecked_scalar_lt(&d_a, 0, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert!(!decrypted);
+
+            let d_result = sks.unchecked_scalar_lt(&d_a, message_modulus, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a < message_modulus);
+
+            // == (as it does not share same code)
+            let d_result = sks.unchecked_scalar_eq(&d_a, 0, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a == 0);
+
+            // != (as it does not share same code)
+            let d_result = sks.unchecked_scalar_ne(&d_a, message_modulus, &stream);
+            let result = d_result.to_boolean_block(&stream);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a != message_modulus);
+        }
+    }
+}
+
 create_gpu_parameterized_test!(integer_unchecked_scalar_min_u256 {
     // TODO GPU DRIFT UPDATE
     PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
@@ -264,3 +364,9 @@ create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge {
     PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
     V1_0_PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
 });
+create_gpu_parameterized_test!(integer_unchecked_scalar_comparisons_edge_one_block {
+    // TODO GPU DRIFT UPDATE
+    PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
+    PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
+    V1_0_PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
+});
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs
index ea5be6f09c..70775ad084 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_comparison.rs
@@ -402,6 +402,88 @@ fn integer_unchecked_scalar_comparisons_edge(param: ClassicPBSParameters) {
     }
 }
 
+fn integer_unchecked_scalar_comparisons_edge_one_block(param: ClassicPBSParameters) {
+    let mut rng = rand::thread_rng();
+
+    let num_block = 1;
+
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let message_modulus = cks.parameters().message_modulus().0;
+
+    for _ in 0..4 {
+        let clear_a = rng.gen_range(0..message_modulus);
+        let clear_b = rng.gen_range(0..message_modulus);
+
+        let a = cks.encrypt_radix(clear_a, num_block);
+
+        // >=
+        {
+            let result = sks.unchecked_scalar_ge_parallelized(&a, clear_b);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a >= clear_b);
+        }
+
+        // >
+        {
+            let result = sks.unchecked_scalar_gt_parallelized(&a, clear_b);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a > clear_b);
+        }
+
+        // <=
+        {
+            let result = sks.unchecked_scalar_le_parallelized(&a, clear_b);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a <= clear_b);
+        }
+
+        // <
+        {
+            let result = sks.unchecked_scalar_lt_parallelized(&a, clear_b);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a < clear_b);
+        }
+
+        // ==
+        {
+            let result = sks.unchecked_scalar_eq_parallelized(&a, clear_b);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a == clear_b);
+        }
+
+        // !=
+        {
+            let result = sks.unchecked_scalar_ne_parallelized(&a, clear_b);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a != clear_b);
+        }
+
+        // Here the goal is to test the branching
+        // made in the scalar sign function
+        //
+        // We are forcing one of the two branches to work on empty slices
+        {
+            let result = sks.unchecked_scalar_lt_parallelized(&a, 0);
+            let decrypted = cks.decrypt_bool(&result);
+            assert!(!decrypted);
+
+            let result = sks.unchecked_scalar_lt_parallelized(&a, message_modulus);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a < message_modulus);
+
+            // == (as it does not share same code)
+            let result = sks.unchecked_scalar_eq_parallelized(&a, 0);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a == 0);
+
+            // != (as it does not share same code)
+            let result = sks.unchecked_scalar_ne_parallelized(&a, message_modulus);
+            let decrypted = cks.decrypt_bool(&result);
+            assert_eq!(decrypted, clear_a != message_modulus);
+        }
+    }
+}
+
 // Given a ciphertext that consists of empty blocks,
 // the function tests whether comparisons still hold.
 fn integer_comparisons_for_empty_blocks(param: ClassicPBSParameters) {
@@ -775,6 +857,14 @@ mod no_coverage {
         V1_0_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64
     });
 
+    create_parameterized_test!(integer_unchecked_scalar_comparisons_edge_one_block {
+        V1_0_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
+        PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
+        V1_0_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
+        // 2M128 is too slow for 4_4, it is estimated to be 2x slower
+        V1_0_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64
+    });
+
     create_parameterized_test!(integer_is_scalar_out_of_bounds {
         PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
         // We don't use PARAM_MESSAGE_3_CARRY_3_KS_PBS,