Give Stream::memPrefetchAsync a default destination device of CU_DEVICE_CPU
and let the "Vector add" test rely on that default when prefetching d_c.

NOTE(review): this patch had been flattened onto a single line. Line breaks
and blank context lines were restored from the hunk headers' line counts;
the indentation of context lines and the position of the blank line in the
second hunk are assumed, not known -- confirm against the repository before
applying (or apply with whitespace checks relaxed).

diff --git a/include/cudawrappers/cu.hpp b/include/cudawrappers/cu.hpp
index a6528f1f..e53171da 100644
--- a/include/cudawrappers/cu.hpp
+++ b/include/cudawrappers/cu.hpp
@@ -436,7 +436,8 @@ class Stream : public Wrapper {
     checkCudaCall(cuMemcpyAsync(dstPtr, srcPtr, size, _obj));
   }
 
-  void memPrefetchAsync(CUdeviceptr devPtr, size_t size, CUdevice dstDevice) {
+  void memPrefetchAsync(CUdeviceptr devPtr, size_t size,
+                        CUdevice dstDevice = CU_DEVICE_CPU) {
     checkCudaCall(cuMemPrefetchAsync(devPtr, size, dstDevice, _obj));
   }
 
diff --git a/tests/test_vector_add.cpp b/tests/test_vector_add.cpp
index 5de62cc9..de42b38d 100644
--- a/tests/test_vector_add.cpp
+++ b/tests/test_vector_add.cpp
@@ -120,7 +120,7 @@ TEST_CASE("Vector add") {
     stream.memPrefetchAsync(d_a, bytesize, device);
     stream.memPrefetchAsync(d_b, bytesize, device);
     stream.launchKernel(function, 1, 1, 1, N, 1, 1, 0, parameters);
-    stream.memPrefetchAsync(d_c, bytesize, CU_DEVICE_CPU);
+    stream.memPrefetchAsync(d_c, bytesize);
     stream.synchronize();
 
     check_arrays_equal(h_c, reference_c.data(), N);