diff --git a/runtimes/nvidia/src/driver/graph.rs b/runtimes/nvidia/src/driver/graph.rs index c0b089e..454048d 100644 --- a/runtimes/nvidia/src/driver/graph.rs +++ b/runtimes/nvidia/src/driver/graph.rs @@ -14,10 +14,10 @@ impl Drop for Graph { } impl Graph { - pub fn new(&self) -> Graph { + pub fn new() -> Self { let mut graph: cuda::CUgraph = null_mut(); cuda::invoke!(cuGraphCreate(&mut graph, 0)); - Graph { + Self { graph, first_node: null_mut(), last_node: null_mut(), @@ -140,12 +140,13 @@ fn params_memcpy3d( len: usize, ty: MemcpyType, ) -> cuda::CUDA_MEMCPY3D { + use cuda::CUmemorytype::*; let mut ans = cuda::CUDA_MEMCPY3D { srcXInBytes: 0, srcY: 0, srcZ: 0, srcLOD: 0, - srcMemoryType: cuda::CUmemorytype_enum::CU_MEMORYTYPE_DEVICE, + srcMemoryType: CU_MEMORYTYPE_DEVICE, srcHost: null_mut(), srcDevice: 0, srcArray: null_mut(), @@ -156,7 +157,7 @@ fn params_memcpy3d( dstY: 0, dstZ: 0, dstLOD: 0, - dstMemoryType: cuda::CUmemorytype_enum::CU_MEMORYTYPE_DEVICE, + dstMemoryType: CU_MEMORYTYPE_DEVICE, dstHost: null_mut(), dstDevice: 0, dstArray: null_mut(), @@ -173,19 +174,19 @@ fn params_memcpy3d( ans.dstDevice = dst as _; } MemcpyType::H2H => { - ans.srcMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST; + ans.srcMemoryType = CU_MEMORYTYPE_HOST; ans.srcHost = src as _; - ans.dstMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST; + ans.dstMemoryType = CU_MEMORYTYPE_HOST; ans.dstHost = dst as _; } MemcpyType::H2D => { - ans.srcMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST; + ans.srcMemoryType = CU_MEMORYTYPE_HOST; ans.srcHost = src as _; ans.dstDevice = dst as _; } MemcpyType::D2H => { ans.srcDevice = src as _; - ans.dstMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST; + ans.dstMemoryType = CU_MEMORYTYPE_HOST; ans.dstHost = dst as _; } }; diff --git a/runtimes/nvidia/src/driver/mod.rs b/runtimes/nvidia/src/driver/mod.rs index 52219df..d08daeb 100644 --- a/runtimes/nvidia/src/driver/mod.rs +++ b/runtimes/nvidia/src/driver/mod.rs @@ -21,6 +21,11 @@ mod graph; mod memory; mod stream; +#[inline(always)] +pub(crate) fn init() { + bindings::invoke!(cuInit(0)); +} + trait AsRaw { unsafe fn as_raw(&self) -> T; } @@ -29,6 +34,7 @@ trait WithCtx { unsafe fn ctx(&self) -> bindings::CUcontext; } -pub(crate) use graph::Graph; +pub(crate) use context::{Context, ContextGuard}; +pub(crate) use device::devices; +pub(crate) use graph::{ExecutableGraph, Graph}; pub(crate) use memory::Blob; -pub(crate) use stream::Stream; diff --git a/runtimes/nvidia/src/graph.rs b/runtimes/nvidia/src/graph.rs index e338ef7..b99f96e 100644 --- a/runtimes/nvidia/src/graph.rs +++ b/runtimes/nvidia/src/graph.rs @@ -1,15 +1,63 @@ -use crate::driver; +use crate::driver::{self, ContextGuard}; use graph_topo::GraphTopo; +use stack_calculator::{flat, unidir, RealtimeCalculator}; +use std::sync::Arc; pub struct Graph { - graph: driver::Graph, + ctx: Arc, + graph: driver::ExecutableGraph, topology: GraphTopo, edges: Vec, static_mem: driver::Blob, stack: driver::Blob, } -enum MemOffset { - Static(usize), - Stack(usize), +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[repr(transparent)] +struct MemOffset(usize); + +impl MemOffset { + const INVALID: MemOffset = MemOffset(usize::MAX); +} + +impl Graph { + #[inline] + pub fn new(src: &computation::Graph, dev: usize) -> Self { + driver::devices()[dev] + .context() + .apply(|ctx| ctx.runtime_graph(src)) + } + + #[inline] + pub fn run(&self) { + self.ctx.apply(|ctx| { + let stream = ctx.stream(); + unsafe { self.graph.launch_on(&stream) } + }) + } +} + +impl ContextGuard<'_> { + pub fn runtime_graph(&self, src: &computation::Graph) -> Graph { + let src = &src.0; + + let mut flat = flat::RealtimeCalculator::default(); + let mut unidir = unidir::RealtimeCalculator::default(); + + let mut edges = vec![MemOffset::INVALID; src.edges.len()]; + + driver::init(); + let graph = driver::Graph::new(); + + let mut static_mem = self.malloc(flat.peak()); + + Graph { + ctx: self.clone_ctx(), + graph: graph.instantiate(self), + topology: src.topology.clone(), + edges, + static_mem, + stack: self.malloc(unidir.peak()), + } + } } diff --git a/runtimes/nvidia/src/lib.rs b/runtimes/nvidia/src/lib.rs index 0405067..02ac6ec 100644 --- a/runtimes/nvidia/src/lib.rs +++ b/runtimes/nvidia/src/lib.rs @@ -2,7 +2,7 @@ #![cfg(detected_cuda)] -use graph_topo::GraphTopo; - mod driver; mod graph; + +pub use graph::Graph; diff --git a/stack-calculator/src/flat.rs b/stack-calculator/src/flat.rs index 327f11b..4b56712 100644 --- a/stack-calculator/src/flat.rs +++ b/stack-calculator/src/flat.rs @@ -1,10 +1,12 @@ -use crate::{align, Calculator}; -use std::{alloc::Layout, collections::HashSet}; +//! 平铺对象的栈计算器,包括一个实时的版本和一个非实时的版本。 + +use crate::RealtimeCalculator as _; +use std::{alloc::Layout, collections::HashSet, ops::Range}; /// 平铺对象的栈计算器。 -pub struct FlatCalculator; +pub struct Calculator; -impl Calculator for FlatCalculator { +impl crate::Calculator for Calculator { fn calculate( self, topology: &graph_topo::GraphTopo, @@ -12,26 +14,44 @@ impl Calculator for FlatCalculator { ) -> usize { let global_outputs = HashSet::::from_iter(topology.global_outputs()); - let mut ans = 0; + let mut rt_cal = RealtimeCalculator::default(); for (i, _inputs, outputs) in topology { for i in outputs { if !global_outputs.contains(&i) { - manager.set_tensor_offset(i, put_obj(&mut ans, manager.tensor_layout(i))); + manager.set_tensor_offset(i, rt_cal.alloc(manager.tensor_layout(i)).start); } } - manager.set_workspace_offset(i, put_obj(&mut ans, manager.workspace_layout(i))); + manager.set_workspace_offset(i, rt_cal.alloc(manager.workspace_layout(i)).start); } - ans + rt_cal.peak() } } -#[inline(always)] -fn put_obj(size: &mut usize, obj: Layout) -> usize { - if obj.size() == 0 { - *size - } else { - let offset = align(*size, obj.align()); - *size = offset + obj.size(); - offset +/// 实时的平铺对象的栈计算器。 +#[derive(Default, Debug)] +pub struct RealtimeCalculator { + pos: usize, +} + +impl crate::RealtimeCalculator for RealtimeCalculator { + fn alloc(&mut self, obj: Layout) -> Range { + if obj.size() == 0 { + return 0..0; + } + + let start = crate::align(self.pos, obj.align()); + self.pos = start + obj.size(); + + start..self.pos + } + + #[inline] + fn free(&mut self, _range: Range) { + // Nothing to do. + } + + #[inline] + fn peak(&self) -> usize { + self.pos } } diff --git a/stack-calculator/src/lib.rs b/stack-calculator/src/lib.rs index 082a8dd..92715ae 100644 --- a/stack-calculator/src/lib.rs +++ b/stack-calculator/src/lib.rs @@ -10,14 +10,11 @@ #![deny(warnings, missing_docs)] -mod flat; -mod unidir; +pub mod flat; +pub mod unidir; use graph_topo::GraphTopo; -use std::alloc::Layout; - -pub use flat::FlatCalculator; -pub use unidir::UnidirCalculator; +use std::{alloc::Layout, ops::Range}; /// 栈计算器。 pub trait Calculator { @@ -25,6 +22,18 @@ pub trait Calculator { fn calculate(self, topology: &GraphTopo, manager: &mut impl Manager) -> usize; } +/// 实时栈计算器。 +pub trait RealtimeCalculator { + /// 分配满足 `obj` 要求的空间。 + fn alloc(&mut self, obj: Layout) -> Range; + + /// 释放 `range` 范围内的空间。 + fn free(&mut self, range: Range); + + /// 获取栈空间的历史峰值。 + fn peak(&self) -> usize; +} + /// 栈计算管理器。 pub trait Manager { /// 获取张量的数量。 diff --git a/stack-calculator/src/unidir.rs b/stack-calculator/src/unidir.rs index 3316439..a35262d 100644 --- a/stack-calculator/src/unidir.rs +++ b/stack-calculator/src/unidir.rs @@ -1,4 +1,6 @@ -use crate::{align, Calculator}; +//! 单向扩容的栈计算器,包括一个实时的版本和一个非实时的版本。 + +use crate::RealtimeCalculator as _; use std::{ alloc::Layout, cmp::Ordering, @@ -7,9 +9,9 @@ use std::{ }; /// 单向扩容的栈计算器。 -pub struct UnidirCalculator; +pub struct Calculator; -impl Calculator for UnidirCalculator { +impl crate::Calculator for Calculator { fn calculate( self, topology: &graph_topo::GraphTopo, @@ -59,8 +61,9 @@ impl Calculator for UnidirCalculator { } } +/// 实时的单向扩容栈计算器。 #[derive(Default, Debug)] -struct RealtimeCalculator { +pub struct RealtimeCalculator { used: usize, peak: usize, @@ -69,7 +72,7 @@ struct RealtimeCalculator { free_tail_head: HashMap, } -impl RealtimeCalculator { +impl crate::RealtimeCalculator for RealtimeCalculator { fn alloc(&mut self, obj: Layout) -> Range { if obj.size() == 0 { return 0..0; @@ -79,7 +82,7 @@ impl RealtimeCalculator { if let Some(&HeadTail(Range { start, end })) = self .free_headtails .range(HeadTail(0..obj.size())..) - .find(|&HeadTail(r)| r.end - align(r.start, obj.align()) >= obj.size()) + .find(|&HeadTail(r)| r.end - crate::align(r.start, obj.align()) >= obj.size()) { self.free_headtails.remove(&HeadTail(start..end)); self.free_head_tail.remove(&start); @@ -126,10 +129,12 @@ impl RealtimeCalculator { } #[inline] - const fn peak(&self) -> usize { + fn peak(&self) -> usize { self.peak } +} +impl RealtimeCalculator { #[inline] fn insert(&mut self, start: usize, end: usize) { if end > start { @@ -141,7 +146,7 @@ impl RealtimeCalculator { #[inline(always)] const fn head_tail(start: usize, obj: Layout) -> (usize, usize) { - let head = align(start, obj.align()); + let head = crate::align(start, obj.align()); (head, head + obj.size()) } }