Skip to content

Commit

Permalink
Merge pull request #9 from edgenai/fix/cuda-context
Browse files: browse the repository at this point in the history
Fix/cuda context
  • Loading branch information
pedro-devv authored Apr 2, 2024
2 parents 00e6bb6 + 3bc243f commit f10c977
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 21 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/memonitor-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "memonitor-sys"
description = "Automatically generated bindings for some of memonitor's backends."
version = "0.2.1"
version = "0.2.2"
authors = ["Pedro Valente <[email protected]>"]
license = "Apache-2.0"
repository = "https://github.com/edgenai/memonitor"
Expand Down
48 changes: 35 additions & 13 deletions crates/memonitor-sys/cuda/src/memonitor.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ typedef CUresult (*cuCtxCreate_type)(CUcontext *, unsigned int, CUdevice);

typedef CUresult (*cuCtxDestroy_type)(CUcontext);

typedef CUresult (*cuCtxPopCurrent_type)(CUcontext *);

typedef CUresult (*cuCtxPushCurrent_type)(CUcontext ctx);

typedef CUresult (*cuCtxSetCurrent_type)(CUcontext);

typedef CUresult (*cuDeviceGetCount_type)(int *);
Expand All @@ -42,17 +46,19 @@ struct Device {
CUdevice_v1 inner;
};

mod_type module = NULL;
static mod_type module = NULL;

cuInit_type cuInit = NULL;
cuCtxCreate_type cuCtxCreate = NULL;
cuCtxDestroy_type cuCtxDestroy = NULL;
cuCtxSetCurrent_type cuCtxSetCurrent = NULL;
cuDeviceGetCount_type cuDeviceGetCount = NULL;
cuDeviceGet_type cuDeviceGet = NULL;
cuDeviceGetName_type cuDeviceGetName = NULL;
cuDeviceTotalMem_type cuDeviceTotalMem = NULL;
cuMemGetInfo_type cuMemGetInfo = NULL;
static cuInit_type cuInit = NULL;
static cuCtxCreate_type cuCtxCreate = NULL;
static cuCtxDestroy_type cuCtxDestroy = NULL;
static cuCtxPopCurrent_type cuCtxPopCurrent = NULL;
static cuCtxPushCurrent_type cuCtxPushCurrent = NULL;
static cuCtxSetCurrent_type cuCtxSetCurrent = NULL;
static cuDeviceGetCount_type cuDeviceGetCount = NULL;
static cuDeviceGet_type cuDeviceGet = NULL;
static cuDeviceGetName_type cuDeviceGetName = NULL;
static cuDeviceTotalMem_type cuDeviceTotalMem = NULL;
static cuMemGetInfo_type cuMemGetInfo = NULL;


int cu_init() {
Expand All @@ -65,6 +71,8 @@ int cu_init() {
cuInit = (cuInit_type) GetProcAddress(module, "cuInit");
cuCtxCreate = (cuCtxCreate_type) GetProcAddress(module, "cuCtxCreate");
cuCtxDestroy = (cuCtxDestroy_type) GetProcAddress(module, "cuCtxDestroy");
cuCtxPopCurrent = (cuCtxPopCurrent_type) GetProcAddress(module, "cuCtxPopCurrent");
cuCtxPushCurrent = (cuCtxPushCurrent_type) GetProcAddress(module, "cuCtxPushCurrent");
cuCtxSetCurrent = (cuCtxSetCurrent_type) GetProcAddress(module, "cuCtxSetCurrent");
cuDeviceGetCount = (cuDeviceGetCount_type) GetProcAddress(module, "cuDeviceGetCount");
cuDeviceGet = (cuDeviceGet_type) GetProcAddress(module, "cuDeviceGet");
Expand All @@ -82,6 +90,8 @@ int cu_init() {
cuInit = (cuInit_type) dlsym(module, "cuInit");
cuCtxCreate = (cuCtxCreate_type) dlsym(module, "cuCtxCreate");
cuCtxDestroy = (cuCtxDestroy_type) dlsym(module, "cuCtxDestroy");
cuCtxPopCurrent = (cuCtxPopCurrent_type) dlsym(module, "cuCtxPopCurrent");
cuCtxPushCurrent = (cuCtxPushCurrent_type) dlsym(module, "cuCtxPushCurrent");
cuCtxSetCurrent = (cuCtxSetCurrent_type) dlsym(module, "cuCtxSetCurrent");
cuDeviceGetCount = (cuDeviceGetCount_type) dlsym(module, "cuDeviceGetCount");
cuDeviceGet = (cuDeviceGet_type) dlsym(module, "cuDeviceGet");
Expand Down Expand Up @@ -155,6 +165,13 @@ struct cu_Devices cu_list_devices() {
free(ctx_handles);
return invalid_devices;
}

res = cuCtxPopCurrent(&ctx_handles[d]);
if (res != 0) {
free(device_handles);
free(ctx_handles);
return invalid_devices;
}
}

struct cu_Devices devices = {
Expand Down Expand Up @@ -187,7 +204,7 @@ struct cu_DeviceRef cu_get_device(struct cu_Devices *devices, uint32_t index) {

CUdevice *cast_devices = devices->devices_handle;
CUcontext *cast_ctxs = devices->ctx_handle;
struct cu_DeviceRef ref = {cast_devices[index], cast_ctxs[index]};
struct cu_DeviceRef ref = {cast_devices[index], &cast_ctxs[index]};
return ref;
}

Expand Down Expand Up @@ -220,8 +237,8 @@ struct cu_DeviceMemoryProperties cu_device_memory_properties(struct cu_DeviceRef
return invalid_properties;
}

CUcontext cast_ctx = device.ctx_handle;
CUresult res = cuCtxSetCurrent(cast_ctx);
CUcontext *cast_ctx = device.ctx_handle;
CUresult res = cuCtxPushCurrent(*cast_ctx);
if (res != 0) {
return invalid_properties;
}
Expand All @@ -234,6 +251,11 @@ struct cu_DeviceMemoryProperties cu_device_memory_properties(struct cu_DeviceRef
return invalid_properties;
}

res = cuCtxPopCurrent(cast_ctx);
if (res != 0) {
return invalid_properties;
}

struct cu_DeviceMemoryProperties props = {
free_memory,
total_memory - free_memory,
Expand Down
6 changes: 3 additions & 3 deletions crates/memonitor-sys/vulkan/src/memonitor.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#define ARRAY_LEN(array) sizeof(array) / sizeof(array[0])

#ifdef USE_VALIDATION_LAYERS
const char *layer_names[1] = {"VK_LAYER_KHRONOS_validation"};
static const char *layer_names[1] = {"VK_LAYER_KHRONOS_validation"};
#else
//const char *layer_names[0] = {};
//static const char *layer_names[0] = {};
#endif // USE_VALIDATION_LAYERS

const char *extension_names[1] = {"VK_KHR_get_physical_device_properties2"};
static const char *extension_names[1] = {"VK_KHR_get_physical_device_properties2"};

/**
* Check if the required layers are supported locally.
Expand Down
4 changes: 2 additions & 2 deletions crates/memonitor/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "memonitor"
description = "Query CPU and GPU memory information in a portable way."
version = "0.2.1"
version = "0.2.2"
authors = ["Pedro Valente <[email protected]>"]
license = "Apache-2.0"
repository = "https://github.com/edgenai/memonitor"
Expand All @@ -10,7 +10,7 @@ edition = "2021"
publish = true

[dependencies]
memonitor-sys = { path = "../memonitor-sys", version = "0.2.1", default-features = false }
memonitor-sys = { path = "../memonitor-sys", version = "0.2.2", default-features = false }
once_cell = "1.19.0"
sysinfo = "0.30.7"
tracing = "0.1.40"
Expand Down
3 changes: 3 additions & 0 deletions crates/memonitor/src/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System};
use tracing::debug;

use crate::{BackendHandle, BackendId, DeviceHandle, DeviceKind, MemoryStats, CPU_NAME};

Expand All @@ -10,6 +11,8 @@ pub(super) struct Host {

impl Host {
pub(super) fn init() -> (Self, Vec<Cpu>) {
debug!("Attempting to load CPU monitor");

let mut system = System::new_with_specifics(
RefreshKind::new()
.with_cpu(CpuRefreshKind::default())
Expand Down
3 changes: 3 additions & 0 deletions crates/memonitor/src/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ffi::CStr;
use std::ptr::addr_of_mut;

use tracing::debug;

use memonitor_sys::cuda;
Expand All @@ -16,6 +17,8 @@ unsafe impl Sync for Cuda {}

impl Cuda {
pub(super) fn init() -> Option<(Self, Vec<CudaDevice>)> {
debug!("Attempting to load CUDA monitor");

let res = unsafe { cuda::init() };

if res == 0 {
Expand Down
2 changes: 2 additions & 0 deletions crates/memonitor/src/vulkan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ unsafe impl Sync for Vulkan {}

impl Vulkan {
pub(super) fn init() -> Option<(Self, Vec<VulkanDevice>)> {
debug!("Attempting to load Vulkan monitor");

let res = unsafe { vulkan::init() };
if res == 0 {
let mut c_devices = unsafe { vulkan::list_devices() };
Expand Down

0 comments on commit f10c977

Please sign in to comment.