Skip to content

Commit

Permalink
Rewritten cu::DeviceMemory::operator* for more tolerant casting
Browse files Browse the repository at this point in the history
Currently, it is only allowed use the deference operator in case of managed memory. This commit
relaxes this requirement a bit by also allowing access to non-managed memory. This enables casts like this:

cu::DeviceMemory(1024) mem;
float* ptr = static_cast<float*>(mem);

therefore avoiding an intermediate cast to CUdeviceptr.
  • Loading branch information
wvbbreu committed Nov 6, 2024
1 parent 8437f14 commit f2e6d33
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
32 changes: 22 additions & 10 deletions include/cudawrappers/cu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ inline void checkCudaCall(CUresult result) {
if (result != CUDA_SUCCESS) throw Error(result);
}

template <typename T>
inline void checkPointerAccess(const T &pointer) {
CUmemorytype memoryType;
checkCudaCall(cuPointerGetAttribute(
&memoryType, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer));

if (memoryType != CU_MEMORYTYPE_DEVICE &&
memoryType != CU_MEMORYTYPE_UNIFIED) {
throw std::runtime_error(
"Invalid memory type: only CU_MEMORYTYPE_DEVICE and "
"CU_MEMORYTYPE_UNIFIED are supported.");
}
}

inline void init(unsigned flags = 0) { checkCudaCall(cuInit(flags)); }

inline int driverGetVersion() {
Expand Down Expand Up @@ -632,19 +646,17 @@ class DeviceMemory : public Wrapper<CUdeviceptr> {
{
return &_obj;
}
void *parameter_copy() { return reinterpret_cast<void *>(_obj); }

template <typename T>
operator T *() {
int data;
checkCudaCall(
cuPointerGetAttribute(&data, CU_POINTER_ATTRIBUTE_IS_MANAGED, _obj));
if (data) {
return reinterpret_cast<T *>(_obj);
} else {
throw std::runtime_error(
"Cannot return memory of type CU_MEMORYTYPE_DEVICE as pointer.");
}
checkPointerAccess(_obj);
return reinterpret_cast<T *>(_obj);
}

template <typename T>
operator T *() const {
checkPointerAccess(_obj);
return reinterpret_cast<T const *>(_obj);
}

size_t size() const { return _size; }
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ TEST_CASE("Test cu::DeviceMemory", "[devicememory]") {
SECTION("Test cu::DeviceMemory with CU_MEMORYTYPE_DEVICE as host pointer") {
cu::DeviceMemory mem(sizeof(float), CU_MEMORYTYPE_DEVICE, 0);
float* ptr;
CHECK_THROWS(ptr = mem);
CHECK_NOTHROWS(ptr = mem);
}

SECTION("Test cu::DeviceMemory with CU_MEMORYTYPE_UNIFIED as host pointer") {
Expand Down

0 comments on commit f2e6d33

Please sign in to comment.