Skip to content

Commit

Permalink
fix(gpu): fix default pbs with many luts
Browse files (browse the repository at this point in the history)
  • Loading branch information
guillermo-oyarzun committed Oct 14, 2024
1 parent ff0609f commit 748ec04
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 1; i < lut_count; i++) {
auto next_lwe_array_out =
lwe_array_out +
- (i * gridDim.z * (glwe_dimension * polynomial_size + 1));
+ (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
- &next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
+ &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand All @@ -239,9 +239,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)

auto next_lwe_array_out =
lwe_array_out +
- (i * gridDim.z * (glwe_dimension * polynomial_size + 1));
+ (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
- &next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
+ &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 1; i < lut_count; i++) {
auto next_lwe_array_out =
lwe_array_out +
- (i * gridDim.z * (glwe_dimension * polynomial_size + 1));
+ (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
- &next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
+ &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand All @@ -348,9 +348,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)

auto next_lwe_array_out =
lwe_array_out +
- (i * gridDim.z * (glwe_dimension * polynomial_size + 1));
+ (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
- &next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
+ &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,8 +976,8 @@ impl CudaServerKey {
/// // Generate the lookup table for the functions
/// // f1: x -> x*x mod 4
/// // f2: x -> count_ones(x as binary) mod 4
- /// let f1 = |x: u64| x.pow(2) % 4;
- /// let f2 = |x: u64| x.count_ones() as u64 % 4;
+ /// let f1 = |x: u64| x.pow(2) % 8;
+ /// let f2 = |x: u64| x.count_ones() as u64 % 8;
/// // Easy to use for generation
/// let luts = sks.generate_many_lookup_table(&[&f1, &f2]);
/// let vec_res = unsafe { sks.apply_many_lookup_table_async(&d_ct.as_ref(), &luts, &stream) };
Expand Down

0 comments on commit 748ec04

Please sign in to comment.