From 9edfc057fbf3d56c297a150185e97d73b340f353 Mon Sep 17 00:00:00 2001 From: Bram Veenboer Date: Mon, 2 Oct 2023 14:17:04 +0200 Subject: [PATCH] Move asynchronous zero (#226) --- CHANGELOG.md | 1 + include/cudawrappers/cu.hpp | 10 ++++------ tests/test_cu.cpp | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d06c35b..033c2f79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added ### Changed - Made the library header only +- Moved asynchronous `::zero` from `Device` to `Stream` - Replaced `include_cuda_code` helper with `target_embed_source` ### Removed diff --git a/include/cudawrappers/cu.hpp b/include/cudawrappers/cu.hpp index f66bc1bd..810dc917 100644 --- a/include/cudawrappers/cu.hpp +++ b/include/cudawrappers/cu.hpp @@ -415,8 +415,6 @@ class DeviceMemory : public Wrapper { void zero(size_t size) { checkCudaCall(cuMemsetD8(_obj, 0, size)); } - void zero(size_t size, Stream &stream); - const void *parameter() const // used to construct parameter list for launchKernel(); { @@ -484,6 +482,10 @@ class Stream : public Wrapper { checkCudaCall(cuMemPrefetchAsync(devPtr, size, dstDevice, _obj)); } + void zero(CUdeviceptr devPtr, size_t size) { + checkCudaCall(cuMemsetD8Async(devPtr, 0, size, _obj)); + } + void launchKernel(Function &function, unsigned gridX, unsigned gridY, unsigned gridZ, unsigned blockX, unsigned blockY, unsigned blockZ, unsigned sharedMemBytes, @@ -534,10 +536,6 @@ class Stream : public Wrapper { } }; -inline void DeviceMemory::zero(size_t size, Stream &stream) { - checkCudaCall(cuMemsetD8Async(_obj, 0, size, stream)); -} - inline void Event::record(Stream &stream) { checkCudaCall(cuEventRecord(_obj, stream._obj)); } diff --git a/tests/test_cu.cpp b/tests/test_cu.cpp index d9047651..4ab3e772 100644 --- a/tests/test_cu.cpp +++ b/tests/test_cu.cpp @@ -87,7 +87,7 @@ TEST_CASE("Test zeroing cu::DeviceMemory", "[zero]") { cu::Stream stream; stream.memcpyHtoDAsync(mem, src, size); - mem.zero(size, stream); + stream.zero(mem, size); stream.memcpyDtoHAsync(tgt, mem, size); stream.synchronize();