Skip to content

Commit

Permalink
chore(gpu): fix GPU PBS benchmark parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Feb 20, 2024
1 parent d55d68e commit 41c38d1
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ fn pbs_throughput<Scalar: UnsignedTorus + CastInto<usize> + Sync + Send + Serial

#[cfg(feature = "gpu")]
mod cuda {
use super::{benchmark_parameters, multi_bit_benchmark_parameters};
use crate::utilities::{write_to_json, OperatorType};
use super::multi_bit_benchmark_parameters;
use crate::utilities::{write_to_json, CryptoParametersRecord, OperatorType};
use criterion::{black_box, criterion_group, Criterion};
use serde::Serialize;
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
Expand All @@ -551,6 +551,52 @@ mod cuda {
cuda_programmable_bootstrap_lwe_ciphertext, CudaDevice, CudaStream,
};
use tfhe::core_crypto::prelude::*;
use tfhe::keycache::NamedParam;
use tfhe::shortint::parameters::{
PARAM_MESSAGE_1_CARRY_0_KS_PBS, PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_0_KS_PBS, PARAM_MESSAGE_2_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_0_KS_PBS,
PARAM_MESSAGE_3_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_0_KS_PBS, PARAM_MESSAGE_4_CARRY_3_KS_PBS,
PARAM_MESSAGE_5_CARRY_0_KS_PBS, PARAM_MESSAGE_6_CARRY_0_KS_PBS,
PARAM_MESSAGE_7_CARRY_0_KS_PBS,
};
use tfhe::shortint::{ClassicPBSParameters, PBSParameters};

const SHORTINT_CUDA_BENCH_PARAMS: [ClassicPBSParameters; 13] = [
PARAM_MESSAGE_1_CARRY_0_KS_PBS,
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_0_KS_PBS,
PARAM_MESSAGE_2_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_3_CARRY_0_KS_PBS,
PARAM_MESSAGE_3_CARRY_2_KS_PBS,
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_0_KS_PBS,
PARAM_MESSAGE_4_CARRY_3_KS_PBS,
PARAM_MESSAGE_5_CARRY_0_KS_PBS,
PARAM_MESSAGE_6_CARRY_0_KS_PBS,
PARAM_MESSAGE_7_CARRY_0_KS_PBS,
];

fn cuda_benchmark_parameters<Scalar: UnsignedInteger>(
) -> Vec<(String, CryptoParametersRecord<Scalar>)> {
if Scalar::BITS == 64 {
SHORTINT_CUDA_BENCH_PARAMS
.iter()
.map(|params| {
(
params.name(),
<ClassicPBSParameters as Into<PBSParameters>>::into(*params)
.to_owned()
.into(),
)
})
.collect()
} else {
vec![]
}
}

fn cuda_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(c: &mut Criterion) {
let bench_name = "core_crypto::cuda::pbs";
Expand All @@ -568,7 +614,7 @@ mod cuda {
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

for (name, params) in benchmark_parameters::<Scalar>().iter() {
for (name, params) in cuda_benchmark_parameters::<Scalar>().iter() {
// Create the LweSecretKey
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
params.lwe_dimension.unwrap(),
Expand Down

0 comments on commit 41c38d1

Please sign in to comment.