Skip to content

Commit

Permalink
chore(gpu): use wrapping byte add, update rust msrv
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Jun 27, 2024
1 parent ee1c904 commit 773adcc
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 27 deletions.
2 changes: 1 addition & 1 deletion tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ exclude = [
"/js_on_wasm_tests/",
"/web_wasm_parallel_tests/",
]
rust-version = "1.73"
rust-version = "1.75"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
streams.ptr[0],
streams.gpu_indexes[0],
);
ptr = ptr.cast::<u8>().add(size).cast();
ptr = ptr.wrapping_byte_add(size);
for list in cuda_ciphertexts_list_vec {
cuda_memcpy_async_gpu_to_gpu(
ptr,
Expand All @@ -129,7 +129,7 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
streams.ptr[0],
streams.gpu_indexes[0],
);
ptr = ptr.cast::<u8>().add(size).cast();
ptr = ptr.wrapping_byte_add(size);
}
}

Expand Down
16 changes: 4 additions & 12 deletions tfhe/src/core_crypto/gpu/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,8 @@ where
None
} else {
// Shift ptr
let shifted_ptr: *mut c_void = unsafe {
self.ptrs[gpu_index as usize]
.cast::<u8>()
.add(start * std::mem::size_of::<T>())
.cast()
};
let shifted_ptr: *mut c_void =
self.ptrs[gpu_index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;
Expand Down Expand Up @@ -212,12 +208,8 @@ where
let new_len_1 = mid;
let new_len_2 = self.lengths[gpu_index as usize] - mid;
// Shift ptr
let shifted_ptr: *mut c_void = unsafe {
self.ptrs[gpu_index as usize]
.cast::<u8>()
.add(mid * std::mem::size_of::<T>())
.cast()
};
let shifted_ptr: *mut c_void =
self.ptrs[gpu_index as usize].wrapping_byte_add(mid * std::mem::size_of::<T>());

// Create the slice
(
Expand Down
16 changes: 4 additions & 12 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,8 @@ impl<T: Numeric> CudaVec<T> {
None
} else {
// Shift ptr
let shifted_ptr: *mut c_void = unsafe {
self.ptr[gpu_index as usize]
.cast::<u8>()
.add(start * std::mem::size_of::<T>())
.cast()
};
let shifted_ptr: *mut c_void =
self.ptr[gpu_index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;
Expand All @@ -397,12 +393,8 @@ impl<T: Numeric> CudaVec<T> {
None
} else {
// Shift ptr
let shifted_ptr: *mut c_void = unsafe {
self.ptr[gpu_index as usize]
.cast::<u8>()
.add(start * std::mem::size_of::<T>())
.cast()
};
let shifted_ptr: *mut c_void =
self.ptr[gpu_index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;
Expand Down

0 comments on commit 773adcc

Please sign in to comment.