fix(gpu): fix scalar comparisons with 1 block
agnesLeroy committed Feb 24, 2025
1 parent 219c755 commit 7305c18
Showing 7 changed files with 448 additions and 104 deletions.
@@ -4728,4 +4728,5 @@ void update_degrees_after_scalar_bitxor(uint64_t *output_degrees,
uint64_t *clear_degrees,
uint64_t *input_degrees,
uint32_t num_clear_blocks);
std::pair<bool, bool> get_invert_flags(COMPARISON_TYPE compare);
#endif // CUDA_INTEGER_UTILITIES_H
35 changes: 33 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cu
@@ -1,5 +1,36 @@
#include "integer/scalar_comparison.cuh"

#include <iostream>
#include <utility> // for std::pair

std::pair<bool, bool> get_invert_flags(COMPARISON_TYPE compare) {
  bool invert_operands;
  bool invert_subtraction_result;

  switch (compare) {
  case COMPARISON_TYPE::LT:
    invert_operands = false;
    invert_subtraction_result = false;
    break;
  case COMPARISON_TYPE::LE:
    invert_operands = true;
    invert_subtraction_result = true;
    break;
  case COMPARISON_TYPE::GT:
    invert_operands = true;
    invert_subtraction_result = false;
    break;
  case COMPARISON_TYPE::GE:
    invert_operands = false;
    invert_subtraction_result = true;
    break;
  default:
    PANIC("Cuda error: invalid comparison type")
  }

  return {invert_operands, invert_subtraction_result};
}
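
These two flags are exactly the classical reductions of the four order predicates to a single strictly-less-than test: swapping the operands turns `<` into `>`, and inverting the result turns `<` into `>=`. A minimal clear-value sketch (standalone C++ with a pared-down stand-in for the `COMPARISON_TYPE` enum; not part of this diff) that checks the identity exhaustively:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Pared-down stand-in for the backend's COMPARISON_TYPE enum.
enum COMPARISON_TYPE { LT, LE, GT, GE };

// Same flag table as get_invert_flags above.
static std::pair<bool, bool> invert_flags(COMPARISON_TYPE c) {
  switch (c) {
  case LT: return {false, false}; // x < y
  case LE: return {true, true};   // x <= y == !(y < x)
  case GT: return {true, false};  // x > y  ==   y < x
  case GE: return {false, true};  // x >= y == !(x < y)
  }
  return {false, false};
}

int main() {
  for (uint64_t x = 0; x < 8; x++)
    for (uint64_t y = 0; y < 8; y++)
      for (COMPARISON_TYPE op : {LT, LE, GT, GE}) {
        auto [invert_operands, invert_result] = invert_flags(op);
        uint64_t lhs = invert_operands ? y : x;
        uint64_t rhs = invert_operands ? x : y;
        bool got = invert_result ^ (lhs < rhs);
        bool want = (op == LT) ? x < y : (op == LE) ? x <= y
                    : (op == GT) ? x > y : x >= y;
        assert(got == want);
      }
}
```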

void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
void *lwe_array_out, void const *lwe_array_in, void const *scalar_blocks,
@@ -22,9 +53,9 @@ void cuda_scalar_comparison_integer_radix_ciphertext_kb_64(
case GE:
case LT:
case LE:
-    if (lwe_ciphertext_count % 2 != 0)
+    if (lwe_ciphertext_count % 2 != 0 && lwe_ciphertext_count != 1)
       PANIC("Cuda error (scalar comparisons): the number of radix blocks has "
-            "to be even.")
+            "to be even or equal to 1.")
host_integer_radix_scalar_difference_check_kb<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array_out),
283 changes: 190 additions & 93 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_comparison.cuh
@@ -2,6 +2,27 @@
#define CUDA_INTEGER_SCALAR_COMPARISON_OPS_CUH

#include "integer/comparison.cuh"
template <typename Torus>
Torus is_x_less_than_y_given_input_borrow(Torus last_x_block,
                                          Torus last_y_block, Torus borrow,
                                          uint32_t message_modulus) {
  Torus last_bit_pos = log2_int(message_modulus) - 1;
  Torus mask = (1 << last_bit_pos) - 1;
  Torus x_without_last_bit = last_x_block & mask;
  Torus y_without_last_bit = last_y_block & mask;

  bool input_borrow_to_last_bit =
      x_without_last_bit < (y_without_last_bit + borrow);

  Torus result = last_x_block - (last_y_block + borrow);

  Torus output_sign_bit = (result >> last_bit_pos) & 1;
  bool output_borrow = last_x_block < (last_y_block + borrow);

  Torus overflow_flag = (Torus)(input_borrow_to_last_bit ^ output_borrow);

  return output_sign_bit ^ overflow_flag;
}
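
On clear values this is the textbook sign/overflow rule for signed comparison: `x < y` exactly when the sign bit of `x - y` differs from the subtraction's signed-overflow flag, with `borrow` carrying the outcome of the lower blocks. A standalone check (plain C++; `log2_int` re-implemented locally for self-containment; not part of this diff) against ordinary signed comparison on 2-bit blocks:

```cpp
#include <cassert>
#include <cstdint>

using Torus = uint64_t;

// Local stand-in for the backend's log2_int helper.
static uint32_t log2_int(uint32_t v) {
  uint32_t r = 0;
  while (v >>= 1)
    r++;
  return r;
}

// Same arithmetic as is_x_less_than_y_given_input_borrow above.
static Torus is_x_less_than_y(Torus x, Torus y, Torus borrow, uint32_t mod) {
  Torus last_bit_pos = log2_int(mod) - 1;
  Torus mask = (Torus(1) << last_bit_pos) - 1;
  bool input_borrow = (x & mask) < ((y & mask) + borrow);
  Torus result = x - (y + borrow);
  Torus sign_bit = (result >> last_bit_pos) & 1;
  bool output_borrow = x < (y + borrow);
  return sign_bit ^ Torus(input_borrow ^ output_borrow);
}

int main() {
  const uint32_t mod = 4; // 2-bit message blocks
  for (Torus x = 0; x < mod; x++)
    for (Torus y = 0; y < mod; y++)
      for (Torus borrow = 0; borrow <= 1; borrow++) {
        // A block encodes v - mod when its top bit is set, v otherwise.
        int64_t sx = (x >= mod / 2) ? int64_t(x) - mod : int64_t(x);
        int64_t sy = (y >= mod / 2) ? int64_t(y) - mod : int64_t(y);
        assert(is_x_less_than_y(x, y, borrow, mod) ==
               Torus(sx < sy + int64_t(borrow)));
      }
}
```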

template <typename Torus>
__host__ void scalar_compare_radix_blocks_kb(
@@ -205,42 +226,79 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
lwe_array_msb_out, bsks, ksks, 1, lut, lut->params.message_modulus);

} else {
// We only have to do the regular comparison
// And not the part where we compare most significant blocks with zeros
// total_num_radix_blocks == total_num_scalar_blocks
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
uint32_t num_scalar_blocks = total_num_scalar_blocks;

Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
message_modulus);
pack_blocks<Torus>(streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
num_scalar_blocks, message_modulus);

// From this point on we have half the number of blocks
num_lsb_radix_blocks /= 2;
num_scalar_blocks /= 2;

// comparisons will be assigned
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
auto comparisons = mem_ptr->tmp_lwe_array_out;
scalar_compare_radix_blocks_kb<Torus>(streams, gpu_indexes, gpu_count,
comparisons, lhs, rhs, mem_ptr, bsks,
ksks, num_lsb_radix_blocks);

// Reduces a vec containing radix blocks that encrypt a sign
// (inferior, equal, superior) to one single radix block containing the
// final sign
tree_sign_reduction<Torus>(streams, gpu_indexes, gpu_count, lwe_array_out,
comparisons, mem_ptr->diff_buffer->tree_buffer,
sign_handler_f, bsks, ksks,
num_lsb_radix_blocks);
if (total_num_radix_blocks == 1) {
std::pair<bool, bool> invert_flags = get_invert_flags(mem_ptr->op);
Torus scalar = 0;
cuda_memcpy_async_to_cpu(&scalar, scalar_blocks, sizeof(Torus),
streams[0], gpu_indexes[0]);
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
auto one_block_lut_f = [invert_flags, scalar](Torus x) -> Torus {
Torus x_0;
Torus x_1;
if (invert_flags.first) {
x_0 = scalar;
x_1 = x;
} else {
x_0 = x;
x_1 = scalar;
}
auto overflowed = x_0 < x_1;
return (Torus)(invert_flags.second ^ overflowed);
};
int_radix_lut<Torus> *one_block_lut = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, params, 1, 1, true);

generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], one_block_lut->get_lut(0, 0),
one_block_lut->get_degree(0), one_block_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, one_block_lut_f);

one_block_lut->broadcast_lut(streams, gpu_indexes, 0);

legacy_integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, bsks,
ksks, 1, one_block_lut);
one_block_lut->release(streams, gpu_indexes, gpu_count);
delete one_block_lut;
} else {
// We only have to do the regular comparison
// And not the part where we compare most significant blocks with zeros
// total_num_radix_blocks == total_num_scalar_blocks
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;
uint32_t num_scalar_blocks = total_num_scalar_blocks;

Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks,
message_modulus);
pack_blocks<Torus>(streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
num_scalar_blocks, message_modulus);

// From this point on we have half the number of blocks
num_lsb_radix_blocks /= 2;
num_scalar_blocks /= 2;

// comparisons will be assigned
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
auto comparisons = mem_ptr->tmp_lwe_array_out;
scalar_compare_radix_blocks_kb<Torus>(streams, gpu_indexes, gpu_count,
comparisons, lhs, rhs, mem_ptr,
bsks, ksks, num_lsb_radix_blocks);

// Reduces a vec containing radix blocks that encrypt a sign
// (inferior, equal, superior) to one single radix block containing the
// final sign
tree_sign_reduction<Torus>(streams, gpu_indexes, gpu_count, lwe_array_out,
comparisons, mem_ptr->diff_buffer->tree_buffer,
sign_handler_f, bsks, ksks,
num_lsb_radix_blocks);
}
}
}
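
In the multi-block branch above, `pack_blocks` folds each pair of adjacent radix blocks into one block (the higher block scaled by `message_modulus`), halving the number of PBS-driven comparisons, and `scalar_compare_radix_blocks_kb` then assigns each packed pair a ternary sign. A clear-value sketch of that packing and per-pair sign (plain C++ on plaintexts; the block values are arbitrary illustrative choices, and the real kernels of course act on LWE ciphertexts; not part of this diff):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const uint64_t message_modulus = 4; // 2-bit blocks
  // Little-endian radix blocks: lhs encodes 2 + 3*4 + 1*16 + 0*64 = 30,
  // rhs encodes 2 + 3*4 + 0*16 + 1*64 = 78.
  std::vector<uint64_t> lhs = {2, 3, 1, 0};
  std::vector<uint64_t> rhs = {2, 3, 0, 1};

  std::vector<uint64_t> signs;
  for (size_t i = 0; i < lhs.size(); i += 2) {
    // pack_blocks analogue: block i+1 moves into the high half of the pair.
    uint64_t l = lhs[i] + lhs[i + 1] * message_modulus;
    uint64_t r = rhs[i] + rhs[i + 1] * message_modulus;
    // scalar_compare_radix_blocks_kb analogue: 0 = <, 1 = ==, 2 = >.
    signs.push_back(l < r ? 0 : (l == r ? 1 : 2));
  }
  // Prints "1 0": low pair equal, high pair inferior, so lhs < rhs overall.
  for (uint64_t s : signs)
    printf("%llu ", (unsigned long long)s);
  printf("\n");
}
```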

@@ -448,65 +506,104 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
2);

} else {
// We only have to do the regular comparison
// And not the part where we compare most significant blocks with zeros
// total_num_radix_blocks == total_num_scalar_blocks
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;

for (uint j = 0; j < gpu_count; j++) {
cuda_synchronize_stream(streams[j], gpu_indexes[j]);
}
auto lsb_streams = mem_ptr->lsb_streams;
auto msb_streams = mem_ptr->msb_streams;

auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_sign_out =
lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks - 1,
message_modulus);
pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
num_lsb_radix_blocks - 1, message_modulus);

// From this point on we have half the number of blocks
num_lsb_radix_blocks /= 2;

// comparisons will be assigned
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
scalar_compare_radix_blocks_kb<Torus>(lsb_streams, gpu_indexes, gpu_count,
lwe_array_ct_out, lhs, rhs, mem_ptr,
bsks, ksks, num_lsb_radix_blocks);
Torus const *encrypted_sign_block =
lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
Torus const *scalar_sign_block =
scalar_blocks + (total_num_scalar_blocks - 1);

auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
create_trivial_radix<Torus>(
msb_streams[0], gpu_indexes[0], trivial_sign_block, scalar_sign_block,
big_lwe_dimension, 1, 1, message_modulus, carry_modulus);
if (total_num_radix_blocks == 1) {
std::pair<bool, bool> invert_flags = get_invert_flags(mem_ptr->op);
Torus scalar = 0;
cuda_memcpy_async_to_cpu(&scalar, scalar_blocks, sizeof(Torus),
streams[0], gpu_indexes[0]);
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
auto one_block_lut_f = [invert_flags, scalar,
message_modulus](Torus x) -> Torus {
Torus x_0;
Torus x_1;
if (invert_flags.first) {
x_0 = scalar;
x_1 = x;
} else {
x_0 = x;
x_1 = scalar;
}
return (Torus)(invert_flags.second) ^
is_x_less_than_y_given_input_borrow<Torus>(x_0, x_1, 0,
message_modulus);
};
int_radix_lut<Torus> *one_block_lut = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, params, 1, 1, true);

generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], one_block_lut->get_lut(0, 0),
one_block_lut->get_degree(0), one_block_lut->get_max_degree(0),
params.glwe_dimension, params.polynomial_size, params.message_modulus,
params.carry_modulus, one_block_lut_f);

one_block_lut->broadcast_lut(streams, gpu_indexes, 0);

legacy_integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, bsks,
ksks, 1, one_block_lut);
one_block_lut->release(streams, gpu_indexes, gpu_count);
delete one_block_lut;
} else {
// We only have to do the regular comparison
// And not the part where we compare most significant blocks with zeros
// total_num_radix_blocks == total_num_scalar_blocks
uint32_t num_lsb_radix_blocks = total_num_radix_blocks;

for (uint j = 0; j < gpu_count; j++) {
cuda_synchronize_stream(streams[j], gpu_indexes[j]);
}
auto lsb_streams = mem_ptr->lsb_streams;
auto msb_streams = mem_ptr->msb_streams;

auto lwe_array_ct_out = mem_ptr->tmp_lwe_array_out;
auto lwe_array_sign_out =
lwe_array_ct_out + (num_lsb_radix_blocks / 2) * big_lwe_size;
Torus *lhs = diff_buffer->tmp_packed;
Torus *rhs =
diff_buffer->tmp_packed + total_num_radix_blocks / 2 * big_lwe_size;

pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], lhs, lwe_array_in,
big_lwe_dimension, num_lsb_radix_blocks - 1,
message_modulus);
pack_blocks<Torus>(lsb_streams[0], gpu_indexes[0], rhs, scalar_blocks, 0,
num_lsb_radix_blocks - 1, message_modulus);

// From this point on we have half the number of blocks
num_lsb_radix_blocks /= 2;

// comparisons will be assigned
// - 0 if lhs < rhs
// - 1 if lhs == rhs
// - 2 if lhs > rhs
scalar_compare_radix_blocks_kb<Torus>(lsb_streams, gpu_indexes, gpu_count,
lwe_array_ct_out, lhs, rhs, mem_ptr,
bsks, ksks, num_lsb_radix_blocks);
Torus const *encrypted_sign_block =
lwe_array_in + (total_num_radix_blocks - 1) * big_lwe_size;
Torus const *scalar_sign_block =
scalar_blocks + (total_num_scalar_blocks - 1);

auto trivial_sign_block = mem_ptr->tmp_trivial_sign_block;
create_trivial_radix<Torus>(
msb_streams[0], gpu_indexes[0], trivial_sign_block, scalar_sign_block,
big_lwe_dimension, 1, 1, message_modulus, carry_modulus);

legacy_integer_radix_apply_bivariate_lookup_table_kb<Torus>(
msb_streams, gpu_indexes, gpu_count, lwe_array_sign_out,
encrypted_sign_block, trivial_sign_block, bsks, ksks, 1,
mem_ptr->signed_lut, mem_ptr->signed_lut->params.message_modulus);
for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
cuda_synchronize_stream(lsb_streams[j], gpu_indexes[j]);
cuda_synchronize_stream(msb_streams[j], gpu_indexes[j]);
}

legacy_integer_radix_apply_bivariate_lookup_table_kb<Torus>(
msb_streams, gpu_indexes, gpu_count, lwe_array_sign_out,
encrypted_sign_block, trivial_sign_block, bsks, ksks, 1,
mem_ptr->signed_lut, mem_ptr->signed_lut->params.message_modulus);
for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
cuda_synchronize_stream(lsb_streams[j], gpu_indexes[j]);
cuda_synchronize_stream(msb_streams[j], gpu_indexes[j]);
}
// Reduces a vec containing radix blocks that encrypt a sign
// (inferior, equal, superior) to one single radix block containing the
// final sign
reduce_signs<Torus>(streams, gpu_indexes, gpu_count, lwe_array_out,
lwe_array_ct_out, mem_ptr, sign_handler_f, bsks, ksks,
num_lsb_radix_blocks + 1);
}

// Reduces a vec containing radix blocks that encrypt a sign
// (inferior, equal, superior) to one single radix block containing the
// final sign
reduce_signs<Torus>(streams, gpu_indexes, gpu_count, lwe_array_out,
lwe_array_ct_out, mem_ptr, sign_handler_f, bsks, ksks,
num_lsb_radix_blocks + 1);
}
}
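
`tree_sign_reduction` and `reduce_signs` then collapse those per-block signs lexicographically: scanning from the most significant block, the first sign that is not "equal" decides, and `sign_handler_f` finally maps the surviving ternary sign to the Boolean the chosen operator expects. A clear-value sketch of that reduction (plain C++; `reduce_signs_clear` is a hypothetical stand-in for the encrypted tree reduction; not part of this diff):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum Sign : uint64_t { INFERIOR = 0, EQUAL = 1, SUPERIOR = 2 };

// Hypothetical clear analogue of reduce_signs / tree_sign_reduction:
// little-endian input, most significant non-EQUAL sign wins.
static uint64_t reduce_signs_clear(const std::vector<uint64_t> &signs) {
  for (size_t i = signs.size(); i-- > 0;)
    if (signs[i] != EQUAL)
      return signs[i];
  return EQUAL;
}

int main() {
  // Low pair equal, high pair inferior -> the whole value is inferior.
  assert(reduce_signs_clear({EQUAL, INFERIOR}) == INFERIOR);
  // High pair equal, low pair superior -> superior.
  assert(reduce_signs_clear({SUPERIOR, EQUAL}) == SUPERIOR);
  // sign_handler_f analogue for op == LE: true unless strictly superior.
  auto le = [](uint64_t s) { return s != SUPERIOR; };
  assert(le(reduce_signs_clear({EQUAL, EQUAL})) == true);
  assert(le(reduce_signs_clear({SUPERIOR, EQUAL})) == false);
}
```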

13 changes: 10 additions & 3 deletions tfhe/src/high_level_api/booleans/base.rs
@@ -245,9 +245,16 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
FheInt::new(new_ct, key.tag.clone())
}
#[cfg(feature = "gpu")]
-        InternalServerKey::Cuda(_) => {
-            panic!("Cuda devices do not support signed integers")
-        }
+        InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
+            let inner = cuda_key.key.key.if_then_else(
+                &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
+                &*ct_then.ciphertext.on_gpu(streams),
+                &*ct_else.ciphertext.on_gpu(streams),
+                streams,
+            );
+
+            FheInt::new(inner, cuda_key.tag.clone())
+        }),
})
}
}