From 32c93876d7d00f598d4e42a2ff0b648dd2e3d2b2 Mon Sep 17 00:00:00 2001 From: Guillermo Oyarzun Date: Wed, 19 Feb 2025 16:04:46 +0100 Subject: [PATCH] feat(gpu): enable division in high level api --- .../src/high_level_api/integers/signed/ops.rs | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/tfhe/src/high_level_api/integers/signed/ops.rs b/tfhe/src/high_level_api/integers/signed/ops.rs index 4a6f04aa55..44c018e189 100644 --- a/tfhe/src/high_level_api/integers/signed/ops.rs +++ b/tfhe/src/high_level_api/integers/signed/ops.rs @@ -514,9 +514,17 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices does not support division yet") - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let (q, r) = cuda_key.key.key.div_rem( + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), + streams, + ); + ( + FheInt::::new(q, cuda_key.tag.clone()), + FheInt::::new(r, cuda_key.tag.clone()), + ) + }), }) } } @@ -847,9 +855,14 @@ generic_integer_impl_operation!( FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_cuda_key) => { - panic!("Division '/' is not yet supported by Cuda devices") - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); + FheInt::new(inner_result, cuda_key.tag.clone()) + }), }) } }, @@ -893,9 +906,14 @@ generic_integer_impl_operation!( FheInt::new(inner_result, cpu_key.tag.clone()) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_cuda_key) => { - panic!("Remainder/Modulo '%' is not yet supported by Cuda devices") - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let inner_result = + cuda_key + .key + .key + .rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); + FheInt::new(inner_result, cuda_key.tag.clone()) + }), }) } },