Skip to content

Commit

Permalink
Moved checkPointerAccess method to Wrapper<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
wvbbreu committed Nov 7, 2024
1 parent 1ee98d8 commit 7e83d21
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions include/cudawrappers/cu.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <sys/resource.h>
#if !defined CU_WRAPPER_H
#define CU_WRAPPER_H

Expand Down Expand Up @@ -45,20 +46,6 @@ inline void checkCudaCall(CUresult result) {
if (result != CUDA_SUCCESS) throw Error(result);
}

template <typename T>
inline void checkPointerAccess(const T &pointer) {
CUmemorytype memoryType;
checkCudaCall(cuPointerGetAttribute(
&memoryType, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer));

if (memoryType != CU_MEMORYTYPE_DEVICE &&
memoryType != CU_MEMORYTYPE_UNIFIED) {
throw std::runtime_error(
"Invalid memory type: only CU_MEMORYTYPE_DEVICE and "
"CU_MEMORYTYPE_UNIFIED are supported.");
}
}

inline void init(unsigned flags = 0) { checkCudaCall(cuInit(flags)); }

inline int driverGetVersion() {
Expand Down Expand Up @@ -108,6 +95,22 @@ class Wrapper {

explicit Wrapper(T &obj) : _obj(obj) {}

template <CUmemorytype... AllowedMemoryTypes>
inline void checkPointerAccess(const CUdeviceptr &pointer) const {
CUmemorytype memoryType;
checkCudaCall(cuPointerGetAttribute(
&memoryType, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer));

// Check if the memoryType is one of the allowed memory types
bool isAllowed = false;
for (auto allowedType : {AllowedMemoryTypes...}) {
if (memoryType == allowedType) return;
}

throw std::runtime_error(
"Invalid memory type: allowed types are not matched.");
}

T _obj{};
std::shared_ptr<T> manager;
};
Expand Down Expand Up @@ -649,13 +652,13 @@ class DeviceMemory : public Wrapper<CUdeviceptr> {

template <typename T>
operator T *() {
checkPointerAccess(_obj);
checkPointerAccess<CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_UNIFIED>(_obj);
return reinterpret_cast<T *>(_obj);
}

template <typename T>
operator T *() const {
checkPointerAccess(_obj);
checkPointerAccess<CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_UNIFIED>(_obj);
return reinterpret_cast<T const *>(_obj);
}

Expand Down

0 comments on commit 7e83d21

Please sign in to comment.