1
- use crate :: driver:: { self , ContextGuard } ;
2
- use computation:: Tensor ;
3
- use graph_topo:: GraphTopo ;
1
+ use crate :: {
2
+ driver:: { self , ContextGuard } ,
3
+ kernel:: { GraphBuilder , GraphUser , Resources } ,
4
+ } ;
4
5
use stack_calculator:: { flat, unidir, RealtimeCalculator } ;
5
6
use std:: { alloc:: Layout , collections:: BTreeSet , sync:: Arc } ;
6
7
7
8
/// A runtime graph: an instantiated executable driver graph plus the device
/// memory and offset tables needed to feed inputs and read outputs.
pub struct Graph {
    // Owning handle to the driver context all operations run in.
    ctx: Arc<driver::Context>,
    // The instantiated, launchable graph.
    executable: driver::ExecutableGraph,
    #[allow(unused)] // stay here to keep resource lifetime
    resources: Resources,
    // Device buffer holding global inputs/outputs and locally-constant edges.
    static_mem: driver::DevicePtr,
    // Device buffer for intermediate edge tensors and per-node workspaces.
    stack: driver::DevicePtr,
    // Topology plus per-node workspace offsets (`usize`) and per-edge
    // memory offsets (`MemOffset`).
    offsets: graph_topo::Graph<usize, MemOffset>,
}
15
17
16
18
impl Drop for Graph {
@@ -34,23 +36,23 @@ impl Graph {
34
36
pub fn run ( & self ) {
35
37
self . ctx . apply ( |ctx| {
36
38
let stream = ctx. stream ( ) ;
37
- unsafe { self . graph . launch_on ( & stream) }
39
+ unsafe { self . executable . launch_on ( & stream) }
38
40
} )
39
41
}
40
42
41
43
#[ inline]
42
44
pub fn copy_in_one < T > ( & mut self , i : usize , data : & [ T ] ) {
43
- let i = self . topology . global_inputs ( ) . nth ( i) . unwrap ( ) ;
44
- let offset = self . edges [ i] . offset ( ) ;
45
+ let i = self . offsets . topology . global_inputs ( ) . nth ( i) . unwrap ( ) ;
46
+ let offset = self . offsets . edges [ i] . offset ( ) ;
45
47
self . ctx . apply ( |ctx| unsafe {
46
48
self . static_mem . copy_in ( offset, data, ctx) ;
47
49
} ) ;
48
50
}
49
51
50
52
#[ inline]
51
53
pub fn copy_out_one < T > ( & mut self , i : usize , data : & mut [ T ] ) {
52
- let i = self . topology . global_outputs ( ) [ i] ;
53
- let offset = self . edges [ i as usize ] . offset ( ) ;
54
+ let i = self . offsets . topology . global_outputs ( ) [ i] ;
55
+ let offset = self . offsets . edges [ i as usize ] . offset ( ) ;
54
56
self . ctx . apply ( |ctx| unsafe {
55
57
self . static_mem . copy_out ( offset, data, ctx) ;
56
58
} ) ;
@@ -61,11 +63,11 @@ impl Graph {
61
63
where
62
64
I : IntoIterator < Item = ( & ' a usize , & ' a [ T ] ) > ,
63
65
{
64
- let start = self . topology . global_inputs ( ) . start ;
66
+ let start = self . offsets . topology . global_inputs ( ) . start ;
65
67
self . ctx . apply ( |ctx| {
66
68
let stream = ctx. stream ( ) ;
67
69
for ( i, data) in data {
68
- let offset = self . edges [ start + i] . offset ( ) ;
70
+ let offset = self . offsets . edges [ start + i] . offset ( ) ;
69
71
unsafe { self . static_mem . copy_in_async ( offset, data, & stream) } ;
70
72
}
71
73
} ) ;
@@ -76,102 +78,224 @@ impl Graph {
76
78
where
77
79
I : IntoIterator < Item = ( & ' a usize , & ' a mut [ T ] ) > ,
78
80
{
79
- let global_output = self . topology . global_outputs ( ) ;
81
+ let global_output = self . offsets . topology . global_outputs ( ) ;
80
82
self . ctx . apply ( |ctx| {
81
83
let stream = ctx. stream ( ) ;
82
84
for ( i, data) in data {
83
- let offset = self . edges [ global_output[ * i] as usize ] . offset ( ) ;
85
+ let offset = self . offsets . edges [ global_output[ * i] as usize ] . offset ( ) ;
84
86
unsafe { self . static_mem . copy_out_async ( offset, data, & stream) } ;
85
87
}
86
88
} ) ;
87
89
}
88
90
}
89
91
92
// Per-edge reference counter type; `u16` keeps the bookkeeping vectors small.
#[allow(non_camel_case_types)]
type urc = u16;
// Sentinel ref-count marking an edge that lives in the static region
// (never freed during stack simulation).
const STATIC: urc = urc::MAX;
// Alignment applied to every device-buffer layout (CUDA allocation alignment).
const CUDA_ALIGN: usize = 256;
+
90
97
impl ContextGuard<'_> {
    /// Lowers a `computation::Graph` into a runtime [`Graph`]: plans memory
    /// for every edge and per-node workspace, allocates the static and stack
    /// device buffers, uploads locally-constant tensors, and instantiates
    /// the driver graph.
    ///
    /// NOTE(review): still unfinished — the address-resolution loop below
    /// contains `todo!()` placeholders.
    pub fn runtime_graph(&self, src: &computation::Graph) -> Graph {
        let mut static_mem: flat::RealtimeCalculator = flat::RealtimeCalculator::default();
        let mut stack = unidir::RealtimeCalculator::default();

        // Per-node workspace offsets and per-edge memory offsets, filled below.
        let mut nodes = vec![usize::MAX; src.0.nodes.len()];
        let mut edges = vec![MemOffset::INVALID; src.0.edges.len()];
        // Edges holding constant data that must be uploaded to static memory.
        let mut local_edges = BTreeSet::<usize>::new();

        // Count how many times each edge is consumed.
        let mut edge_rc = vec![0 as urc; src.0.edges.len()];
        for edge_idx in src.0.topology.connections() {
            edge_rc[edge_idx] += 1;
        }

        // Global inputs/outputs live in the static region for the whole run.
        src.0
            .topology
            .global_inputs()
            .chain(src.0.topology.global_outputs())
            .for_each(|edge_idx| {
                alloc_static(src, edge_idx, &mut edges, &mut edge_rc, &mut static_mem)
            });

        // Walk nodes in execution order to size workspaces and the stack region.
        let mut builders = Vec::<Box<dyn GraphBuilder>>::with_capacity(src.0.nodes.len());
        let mut resources = Resources::default();
        for (node_idx, inputs, outputs) in &src.0.topology {
            let (op, _) = &src.0.nodes[node_idx];
            let builder = op.builder(&mut resources, self);
            // NOTE(review): `worksapce` is a typo in the builder trait's method
            // name; fix it at the declaration site, not here.
            let workspace = builder.worksapce().align_to(CUDA_ALIGN).unwrap();
            builders.push(builder);

            // alloc for outputs
            for edge_idx in outputs.clone() {
                if edge_rc[edge_idx] != STATIC {
                    alloc_stack(src, edge_idx, &mut edges, &mut stack);
                }
            }
            // alloc for workspaces
            alloc_workspace(workspace, node_idx, &mut nodes, &mut stack);
            // free for temp outputs
            for edge_idx in outputs {
                if edge_rc[edge_idx] == 0 {
                    free_stack(src, edge_idx, &edges[edge_idx], &mut stack);
                }
            }
            // free for inputs or alloc for local static inputs
            for edge_idx in inputs {
                let offset = edges[edge_idx];
                if offset == MemOffset::INVALID {
                    // First sighting of a constant edge: give it static storage
                    // and remember to upload its contents later.
                    local_edges.insert(edge_idx);
                    alloc_static(src, edge_idx, &mut edges, &mut edge_rc, &mut static_mem);
                } else {
                    let rc = &mut edge_rc[edge_idx];
                    debug_assert_ne!(*rc, 0);
                    *rc -= 1;
                    if *rc == 0 {
                        free_stack(src, edge_idx, &offset, &mut stack);
                    }
                }
            }
        }

        // Planning done; actually allocate device memory.
        let resources = resources;
        let edges = edges;
        let (static_mem, stack) = {
            let stream = self.stream();

            let mut static_mem = stream.malloc(static_mem.peak());
            let stack = stream.malloc(stack.peak());

            // Upload locally-constant tensors into their static slots.
            for edge_idx in local_edges {
                let offset = edges[edge_idx].offset();
                let tensor = &src.0.edges[edge_idx].0;
                let ptr = tensor.blob.as_ref().unwrap().get().cast::<u8>();
                let len = tensor.blob_mem_layout().size();
                unsafe {
                    // SAFETY: the blob pointer and its layout size come from
                    // the same tensor; the slice is read-only for the copy.
                    let data = std::slice::from_raw_parts(ptr, len);
                    static_mem.copy_in_async(offset, data, &stream);
                }
            }

            (static_mem, stack)
        };

        let mut graph = driver::Graph::new();
        for (node_idx, inputs, outputs) in &src.0.topology {
            // TODO: translate planned offsets into actual device addresses.
            // NOTE(review): `temp` is only fed input offsets so far, yet the
            // slices passed to `push_to` assume a [workspace, inputs..,
            // outputs..] layout — incomplete, guarded by the `todo!()`s.
            let mut temp = Vec::with_capacity(1 + inputs.len() + outputs.len());
            temp.extend(inputs.iter().map(|i| edges[*i as usize]).map(|offset| {
                if offset.is_static() {
                    todo!()
                } else {
                    todo!()
                }
            }));
            builders[node_idx].push_to(
                &mut graph,
                &resources,
                &temp[0],
                &temp[1..][..inputs.len()],
                &temp[1 + inputs.len()..],
            )
        }

        Graph {
            ctx: self.clone_ctx(),
            executable: graph.instantiate(self),
            resources,
            static_mem,
            stack,
            offsets: graph_topo::Graph {
                topology: src.0.topology.clone(),
                nodes,
                edges,
            },
        }
    }
}
152
218
153
- #[ inline( always) ]
154
- fn cuda_layout ( edge : & ( Tensor , String ) ) -> Layout {
155
- edge. 0 . blob_mem_layout ( ) . align_to ( 256 ) . unwrap ( )
219
+ fn alloc_workspace (
220
+ workspace : Layout ,
221
+ node_idx : usize ,
222
+ nodes : & mut [ usize ] ,
223
+ stack : & mut unidir:: RealtimeCalculator ,
224
+ ) {
225
+ let workspace = stack. alloc ( workspace) ;
226
+ nodes[ node_idx] = workspace. start ;
227
+ stack. free ( workspace) ;
228
+ }
229
+
230
+ fn alloc_stack (
231
+ src : & computation:: Graph ,
232
+ edge_idx : usize ,
233
+ edges : & mut [ MemOffset ] ,
234
+ calculator : & mut unidir:: RealtimeCalculator ,
235
+ ) {
236
+ let layout = src. 0 . edges [ edge_idx]
237
+ . 0
238
+ . blob_mem_layout ( )
239
+ . align_to ( CUDA_ALIGN )
240
+ . unwrap ( ) ;
241
+ let offset = calculator. alloc ( layout) . start ;
242
+ edges[ edge_idx] = MemOffset :: from_stack ( offset) ;
243
+ }
244
+
245
+ fn free_stack (
246
+ src : & computation:: Graph ,
247
+ edge_idx : usize ,
248
+ offset : & MemOffset ,
249
+ calculator : & mut unidir:: RealtimeCalculator ,
250
+ ) {
251
+ let start = offset. offset ( ) ;
252
+ let len = src. 0 . edges [ edge_idx] . 0 . blob_mem_layout ( ) . size ( ) ;
253
+ calculator. free ( start..start + len) ;
254
+ }
255
+
256
+ fn alloc_static (
257
+ src : & computation:: Graph ,
258
+ edge_idx : usize ,
259
+ edges : & mut [ MemOffset ] ,
260
+ edge_rc : & mut [ urc ] ,
261
+ calculator : & mut flat:: RealtimeCalculator ,
262
+ ) {
263
+ let layout = src. 0 . edges [ edge_idx]
264
+ . 0
265
+ . blob_mem_layout ( )
266
+ . align_to ( CUDA_ALIGN )
267
+ . unwrap ( ) ;
268
+ let offset = calculator. alloc ( layout) . start ;
269
+ edges[ edge_idx] = MemOffset :: from_static ( offset) ;
270
+ edge_rc[ edge_idx] = STATIC ;
156
271
}
157
272
158
273
/// A planned memory offset tagged with its arena: the most significant bit
/// distinguishes stack offsets (bit set) from static-region offsets
/// (bit clear).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(transparent)]
struct MemOffset(usize);

impl MemOffset {
    /// Placeholder for an edge that has not been assigned memory yet.
    const INVALID: Self = Self(usize::MAX);
    /// Arena tag: the most significant bit of the offset word.
    const BIT: usize = 1 << (usize::BITS - 1);

    /// Wraps an offset into the static region (tag bit stays clear).
    #[inline]
    const fn from_static(offset: usize) -> Self {
        Self(offset)
    }

    /// Wraps an offset into the stack region (tag bit set).
    #[inline]
    const fn from_stack(offset: usize) -> Self {
        Self(offset | Self::BIT)
    }

    /// `true` when this offset points into the static region.
    #[inline]
    const fn is_static(self) -> bool {
        // Equivalent to `self.0 & Self::BIT == 0`: values below the tag bit
        // have it clear.
        self.0 < Self::BIT
    }

    /// The raw offset with the arena tag stripped.
    #[inline]
    fn offset(self) -> usize {
        debug_assert_ne!(self, Self::INVALID);
        self.0 & !Self::BIT
    }
}
0 commit comments