Skip to content

Commit

Permalink
feat(runtimes/nvidia): 实现 kernel 的编译和查询
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 4, 2024
1 parent d685b82 commit f5df3b4
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
3 changes: 1 addition & 2 deletions runtimes/nvidia/src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ mod context;
mod device;
mod graph;
mod memory;
mod nvrtc;
pub mod nvrtc;
mod stream;

trait AsRaw<T> {
Expand All @@ -45,7 +45,6 @@ trait WithCtx {
unsafe fn ctx(&self) -> bindings::CUcontext;
}

pub(crate) use bindings::CUresult;
pub(crate) use context::{Context, ContextGuard};
pub(crate) use device::devices;
pub(crate) use graph::{ExecutableGraph, Graph};
Expand Down
61 changes: 61 additions & 0 deletions runtimes/nvidia/src/driver/nvrtc.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,77 @@
use super::{bindings as cuda, Context, ContextGuard};
use std::{
collections::{hash_map::Keys, HashMap},
ffi::{c_char, CStr, CString},
ptr::{null, null_mut},
sync::Arc,
sync::{Mutex, OnceLock},
};

static MODULES: OnceLock<Mutex<HashMap<String, Arc<Module>>>> = OnceLock::new();

pub(crate) fn compile<'a, I, U, V>(code: &str, symbols: I, ctx: &ContextGuard)
where
I: IntoIterator<Item = (U, V)>,
U: AsRef<str>,
V: AsRef<str>,
{
let symbols = symbols
.into_iter()
.map(|(k, v)| (k.as_ref().to_owned(), v.as_ref().to_owned()))
.collect::<HashMap<_, _>>();
// 先检查一遍并确保静态对象创建
let modules = if let Some(modules) = MODULES.get() {
if check_hold(&*modules.lock().unwrap(), symbols.keys()) {
return;
}
modules
} else {
MODULES.get_or_init(|| Default::default())
};
// 编译
let (module, log) = Module::from_src(code, ctx);
println!("{log}");
// 再上锁检查一遍
let module = Arc::new(module.unwrap());
let mut map = modules.lock().unwrap();
if !check_hold(&*map, symbols.keys()) {
for k in symbols.keys() {
// 确认指定的符号都存在
module.get_function(k);
map.insert(k.clone(), module.clone());
}
}
}

pub(crate) fn get_function(name: &str) -> Option<cuda::CUfunction> {
MODULES.get().and_then(|modules| {
modules
.lock()
.unwrap()
.get(name)
.map(|module| module.get_function(name))
})
}

fn check_hold(map: &HashMap<String, Arc<Module>>, symbols: Keys<'_, String, String>) -> bool {
let len = symbols.len();
let had = symbols.filter(|&k| map.contains_key(k)).count();
if had == len {
true
} else if had == 0 {
false
} else {
panic!()
}
}

struct Module {
ctx: Arc<Context>,
module: cuda::CUmodule,
}

unsafe impl Send for Module {}
unsafe impl Sync for Module {}

impl Drop for Module {
#[inline]
Expand Down

0 comments on commit f5df3b4

Please sign in to comment.