diff --git a/catkit_core/CudaSharedMemory.h b/catkit_core/CudaSharedMemory.h index 98899e2d..a5e433f5 100644 --- a/catkit_core/CudaSharedMemory.h +++ b/catkit_core/CudaSharedMemory.h @@ -1,6 +1,8 @@ #ifndef CUDA_SHARED_MEMORY_H #define CUDA_SHARED_MEMORY_H +#include "Memory.h" + #include #ifdef HAVE_CUDA @@ -11,7 +13,7 @@ typedef cudaIpcMemHandle_t CudaIpcHandle; typedef char CudaIpcHandle[64]; #endif -class CudaSharedMemory +class CudaSharedMemory : public Memory { private: CudaSharedMemory(const CudaIpcHandle &ipc_handle, void *device_pointer=nullptr); @@ -22,7 +24,7 @@ class CudaSharedMemory static std::shared_ptr Create(size_t num_bytes_in_buffer); static std::shared_ptr Open(const CudaIpcHandle &ipc_handle); - void *GetAddress(); + void *GetAddress(std::size_t offset = 0) override; }; #endif // CUDA_SHARED_MEMORY_H diff --git a/catkit_core/Memory.h b/catkit_core/Memory.h new file mode 100644 index 00000000..9f20e9ed --- /dev/null +++ b/catkit_core/Memory.h @@ -0,0 +1,16 @@ +#ifndef MEMORY_H +#define MEMORY_H + +#include + +class Memory +{ +public: + virtual ~Memory() + { + } + + virtual void *GetAddress(std::size_t offset = 0) = 0; +}; + +#endif // MEMORY_H diff --git a/catkit_core/SharedMemory.cpp b/catkit_core/SharedMemory.cpp index a469b9c4..c8b8d356 100644 --- a/catkit_core/SharedMemory.cpp +++ b/catkit_core/SharedMemory.cpp @@ -82,7 +82,7 @@ SharedMemory::SharedMemory(const std::string &id, FileObject file, bool is_owner throw std::runtime_error("Something went wrong while mapping shared memory file."); } -void *SharedMemory::GetAddress() +void *SharedMemory::GetAddress(std::size_t offset) { - return m_Buffer; + return static_cast(m_Buffer) + offset; } diff --git a/catkit_core/SharedMemory.h b/catkit_core/SharedMemory.h index aa6764bd..ca78b51c 100644 --- a/catkit_core/SharedMemory.h +++ b/catkit_core/SharedMemory.h @@ -1,6 +1,8 @@ #ifndef SHARED_MEMORY_H #define SHARED_MEMORY_H +#include "Memory.h" + #include #include @@ -21,7 +23,7 @@ typedef int FileObject; #endif -class SharedMemory +class SharedMemory : public Memory { private: SharedMemory(const std::string &id, FileObject file, bool is_owner); @@ -32,7 +34,7 @@ class SharedMemory static std::shared_ptr Create(const std::string &id, size_t num_bytes_in_buffer); static std::shared_ptr Open(const std::string &id); - void *GetAddress(); + void *GetAddress(std::size_t offset = 0) override; private: std::string m_Id;