Skip to content

Commit 8997de6

Browse files
committed
feat(runtimes/nvidia): 实现构图中的显存分配
Signed-off-by: YdrMaster <[email protected]>
1 parent 54594c9 commit 8997de6

File tree

4 files changed

+243
-50
lines changed

4 files changed

+243
-50
lines changed

runtimes/nvidia/src/driver/memory.rs

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,20 @@ impl ContextGuard<'_> {
1919
}
2020
}
2121

22+
impl Stream<'_> {
23+
#[inline]
24+
pub fn malloc(&self, size: usize) -> DevicePtr {
25+
let mut ptr: cuda::CUdeviceptr = 0;
26+
cuda::invoke!(cuMemAllocAsync(&mut ptr, size, self.as_raw()));
27+
DevicePtr(ptr)
28+
}
29+
30+
#[inline]
31+
pub fn free(&self, ptr: DevicePtr) {
32+
cuda::invoke!(cuMemFreeAsync(ptr.0, self.as_raw()));
33+
}
34+
}
35+
2236
impl Drop for DevicePtr {
2337
#[inline]
2438
fn drop(&mut self) {

runtimes/nvidia/src/graph.rs

Lines changed: 174 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,18 @@
1-
use crate::driver::{self, ContextGuard};
2-
use computation::Tensor;
3-
use graph_topo::GraphTopo;
1+
use crate::{
2+
driver::{self, ContextGuard},
3+
kernel::{GraphBuilder, GraphUser, Resources},
4+
};
45
use stack_calculator::{flat, unidir, RealtimeCalculator};
56
use std::{alloc::Layout, collections::BTreeSet, sync::Arc};
67

78
pub struct Graph {
89
ctx: Arc<driver::Context>,
9-
graph: driver::ExecutableGraph,
10-
topology: GraphTopo,
11-
edges: Vec<MemOffset>,
10+
executable: driver::ExecutableGraph,
11+
#[allow(unused)] // stay here to keep resource lifetime
12+
resources: Resources,
1213
static_mem: driver::DevicePtr,
1314
stack: driver::DevicePtr,
15+
offsets: graph_topo::Graph<usize, MemOffset>,
1416
}
1517

1618
impl Drop for Graph {
@@ -34,23 +36,23 @@ impl Graph {
3436
pub fn run(&self) {
3537
self.ctx.apply(|ctx| {
3638
let stream = ctx.stream();
37-
unsafe { self.graph.launch_on(&stream) }
39+
unsafe { self.executable.launch_on(&stream) }
3840
})
3941
}
4042

4143
#[inline]
4244
pub fn copy_in_one<T>(&mut self, i: usize, data: &[T]) {
43-
let i = self.topology.global_inputs().nth(i).unwrap();
44-
let offset = self.edges[i].offset();
45+
let i = self.offsets.topology.global_inputs().nth(i).unwrap();
46+
let offset = self.offsets.edges[i].offset();
4547
self.ctx.apply(|ctx| unsafe {
4648
self.static_mem.copy_in(offset, data, ctx);
4749
});
4850
}
4951

5052
#[inline]
5153
pub fn copy_out_one<T>(&mut self, i: usize, data: &mut [T]) {
52-
let i = self.topology.global_outputs()[i];
53-
let offset = self.edges[i as usize].offset();
54+
let i = self.offsets.topology.global_outputs()[i];
55+
let offset = self.offsets.edges[i as usize].offset();
5456
self.ctx.apply(|ctx| unsafe {
5557
self.static_mem.copy_out(offset, data, ctx);
5658
});
@@ -61,11 +63,11 @@ impl Graph {
6163
where
6264
I: IntoIterator<Item = (&'a usize, &'a [T])>,
6365
{
64-
let start = self.topology.global_inputs().start;
66+
let start = self.offsets.topology.global_inputs().start;
6567
self.ctx.apply(|ctx| {
6668
let stream = ctx.stream();
6769
for (i, data) in data {
68-
let offset = self.edges[start + i].offset();
70+
let offset = self.offsets.edges[start + i].offset();
6971
unsafe { self.static_mem.copy_in_async(offset, data, &stream) };
7072
}
7173
});
@@ -76,102 +78,224 @@ impl Graph {
7678
where
7779
I: IntoIterator<Item = (&'a usize, &'a mut [T])>,
7880
{
79-
let global_output = self.topology.global_outputs();
81+
let global_output = self.offsets.topology.global_outputs();
8082
self.ctx.apply(|ctx| {
8183
let stream = ctx.stream();
8284
for (i, data) in data {
83-
let offset = self.edges[global_output[*i] as usize].offset();
85+
let offset = self.offsets.edges[global_output[*i] as usize].offset();
8486
unsafe { self.static_mem.copy_out_async(offset, data, &stream) };
8587
}
8688
});
8789
}
8890
}
8991

92+
#[allow(non_camel_case_types)]
93+
type urc = u16;
94+
const STATIC: urc = urc::MAX;
95+
const CUDA_ALIGN: usize = 256;
96+
9097
impl ContextGuard<'_> {
9198
pub fn runtime_graph(&self, src: &computation::Graph) -> Graph {
92-
let src = &src.0;
93-
94-
let mut static_mem = flat::RealtimeCalculator::default();
99+
let mut static_mem: flat::RealtimeCalculator = flat::RealtimeCalculator::default();
95100
let mut stack = unidir::RealtimeCalculator::default();
96101

97-
let mut edges = vec![MemOffset::INVALID; src.edges.len()];
102+
let mut nodes = vec![usize::MAX; src.0.nodes.len()];
103+
let mut edges = vec![MemOffset::INVALID; src.0.edges.len()];
98104
let mut local_edges = BTreeSet::<usize>::new();
99105

100-
#[allow(non_camel_case_types)]
101-
type urc = u16;
102-
const STATIC: urc = urc::MAX;
103-
let mut edge_rc = vec![0 as urc; src.edges.len()];
104-
for edge_idx in src.topology.connections() {
106+
// 计算边引用计数
107+
let mut edge_rc = vec![0 as urc; src.0.edges.len()];
108+
for edge_idx in src.0.topology.connections() {
105109
edge_rc[edge_idx] += 1;
106110
}
107111

108-
src.topology
112+
// 为输入输出分配静态存储区
113+
src.0
114+
.topology
109115
.global_inputs()
110-
.chain(src.topology.global_outputs())
116+
.chain(src.0.topology.global_outputs())
111117
.for_each(|edge_idx| {
112-
edge_rc[edge_idx] = STATIC;
113-
edges[edge_idx] = MemOffset::from_static(
114-
// 全图输入输出分配在静态存储区
115-
static_mem.alloc(cuda_layout(&src.edges[edge_idx])).start,
116-
);
118+
alloc_static(src, edge_idx, &mut edges, &mut edge_rc, &mut static_mem)
117119
});
118120

119-
let mut graph = driver::Graph::new();
121+
// 计算工作空间需求,分配栈空间
122+
let mut builders = Vec::<Box<dyn GraphBuilder>>::with_capacity(src.0.nodes.len());
123+
let mut resources = Resources::default();
124+
for (node_idx, inputs, outputs) in &src.0.topology {
125+
let (op, _) = &src.0.nodes[node_idx];
126+
let builder = op.builder(&mut resources, self);
127+
let workspace = builder.worksapce().align_to(CUDA_ALIGN).unwrap();
128+
builders.push(builder);
120129

121-
for (node_idx, inputs, outputs) in &src.topology {
122-
let (op, _) = &src.nodes[node_idx];
123-
// TODO 分配栈空间,构造计算节点
130+
// alloc for outputs
131+
for edge_idx in outputs.clone() {
132+
if edge_rc[edge_idx] != STATIC {
133+
alloc_stack(src, edge_idx, &mut edges, &mut stack);
134+
}
135+
}
136+
// alloc for workspaces
137+
alloc_workspace(workspace, node_idx, &mut nodes, &mut stack);
138+
// free for temp outputs
139+
for edge_idx in outputs {
140+
if edge_rc[edge_idx] == 0 {
141+
free_stack(src, edge_idx, &edges[edge_idx], &mut stack);
142+
}
143+
}
144+
// free for inputs or alloc for local static inputs
145+
for edge_idx in inputs {
146+
let offset = edges[edge_idx];
147+
if offset == MemOffset::INVALID {
148+
local_edges.insert(edge_idx);
149+
alloc_static(src, edge_idx, &mut edges, &mut edge_rc, &mut static_mem);
150+
} else {
151+
let rc = &mut edge_rc[edge_idx];
152+
debug_assert_ne!(*rc, 0);
153+
*rc -= 1;
154+
if *rc == 0 {
155+
free_stack(src, edge_idx, &offset, &mut stack);
156+
}
157+
}
158+
}
124159
}
125160

126-
let static_mem = {
161+
// 实际分配显存空间
162+
let resources = resources;
163+
let edges = edges;
164+
let (static_mem, stack) = {
127165
let stream = self.stream();
128-
let mut static_mem = self.malloc(static_mem.peak());
166+
167+
let mut static_mem = stream.malloc(static_mem.peak());
168+
let stack = stream.malloc(stack.peak());
169+
129170
for edge_idx in local_edges {
130171
let offset = edges[edge_idx].offset();
131-
let tensor = &src.edges[edge_idx].0;
172+
let tensor = &src.0.edges[edge_idx].0;
132173
let ptr = tensor.blob.as_ref().unwrap().get().cast::<u8>();
133174
let len = tensor.blob_mem_layout().size();
134175
unsafe {
135176
let data = std::slice::from_raw_parts(ptr, len);
136177
static_mem.copy_in_async(offset, data, &stream);
137178
}
138179
}
139-
static_mem
180+
181+
(static_mem, stack)
140182
};
141183

184+
let mut graph = driver::Graph::new();
185+
for (node_idx, inputs, outputs) in &src.0.topology {
186+
// TODO 计算实际地址
187+
let mut temp = Vec::with_capacity(1 + inputs.len() + outputs.len());
188+
temp.extend(inputs.iter().map(|i| edges[*i as usize]).map(|offset| {
189+
if offset.is_static() {
190+
todo!()
191+
} else {
192+
todo!()
193+
}
194+
}));
195+
builders[node_idx].push_to(
196+
&mut graph,
197+
&resources,
198+
&temp[0],
199+
&temp[1..][..inputs.len()],
200+
&temp[1 + inputs.len()..],
201+
)
202+
}
203+
142204
Graph {
143205
ctx: self.clone_ctx(),
144-
graph: graph.instantiate(self),
145-
topology: src.topology.clone(),
146-
edges,
206+
executable: graph.instantiate(self),
207+
resources,
147208
static_mem,
148-
stack: self.malloc(stack.peak()),
209+
stack,
210+
offsets: graph_topo::Graph {
211+
topology: src.0.topology.clone(),
212+
nodes,
213+
edges,
214+
},
149215
}
150216
}
151217
}
152218

153-
#[inline(always)]
154-
fn cuda_layout(edge: &(Tensor, String)) -> Layout {
155-
edge.0.blob_mem_layout().align_to(256).unwrap()
219+
fn alloc_workspace(
220+
workspace: Layout,
221+
node_idx: usize,
222+
nodes: &mut [usize],
223+
stack: &mut unidir::RealtimeCalculator,
224+
) {
225+
let workspace = stack.alloc(workspace);
226+
nodes[node_idx] = workspace.start;
227+
stack.free(workspace);
228+
}
229+
230+
fn alloc_stack(
231+
src: &computation::Graph,
232+
edge_idx: usize,
233+
edges: &mut [MemOffset],
234+
calculator: &mut unidir::RealtimeCalculator,
235+
) {
236+
let layout = src.0.edges[edge_idx]
237+
.0
238+
.blob_mem_layout()
239+
.align_to(CUDA_ALIGN)
240+
.unwrap();
241+
let offset = calculator.alloc(layout).start;
242+
edges[edge_idx] = MemOffset::from_stack(offset);
243+
}
244+
245+
fn free_stack(
246+
src: &computation::Graph,
247+
edge_idx: usize,
248+
offset: &MemOffset,
249+
calculator: &mut unidir::RealtimeCalculator,
250+
) {
251+
let start = offset.offset();
252+
let len = src.0.edges[edge_idx].0.blob_mem_layout().size();
253+
calculator.free(start..start + len);
254+
}
255+
256+
fn alloc_static(
257+
src: &computation::Graph,
258+
edge_idx: usize,
259+
edges: &mut [MemOffset],
260+
edge_rc: &mut [urc],
261+
calculator: &mut flat::RealtimeCalculator,
262+
) {
263+
let layout = src.0.edges[edge_idx]
264+
.0
265+
.blob_mem_layout()
266+
.align_to(CUDA_ALIGN)
267+
.unwrap();
268+
let offset = calculator.alloc(layout).start;
269+
edges[edge_idx] = MemOffset::from_static(offset);
270+
edge_rc[edge_idx] = STATIC;
156271
}
157272

158273
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
159274
#[repr(transparent)]
160275
struct MemOffset(usize);
161276

162277
impl MemOffset {
163-
const INVALID: MemOffset = MemOffset(usize::MAX);
278+
const INVALID: Self = Self(usize::MAX);
164279
const BIT: usize = 1 << (usize::BITS - 1);
165280

166-
fn from_static(offset: usize) -> Self {
281+
#[inline]
282+
const fn from_static(offset: usize) -> Self {
283+
Self(offset)
284+
}
285+
286+
#[inline]
287+
const fn from_stack(offset: usize) -> Self {
167288
Self(offset | Self::BIT)
168289
}
169290

170-
fn is_static(self) -> bool {
171-
self.0 & Self::BIT != 0
291+
#[inline]
292+
const fn is_static(self) -> bool {
293+
self.0 & Self::BIT == 0
172294
}
173295

296+
#[inline]
174297
fn offset(self) -> usize {
298+
debug_assert_ne!(self, Self::INVALID);
175299
self.0 & !Self::BIT
176300
}
177301
}

0 commit comments

Comments (0)