3
3
export @roc , rocconvert, rocfunction
4
4
5
5
"""
    Kernel{F,TT}

Handle for a compiled kernel, as stored in the compile cache by `rocfunction`
(which builds it as `Kernel{f,tt}(device, mod, fun)`).

# Fields
- `device::RuntimeDevice`: the device that was passed to `compile` for this kernel.
- `mod::ROCModule`: the module returned by `compile`.
- `fun::ROCFunction`: the function returned by `compile`.
"""
struct Kernel{F,TT}
    device::RuntimeDevice
    mod::ROCModule
    fun::ROCFunction
end
10
10
11
11
# `split_kwargs()` segregates keyword arguments passed to `@roc` into those
12
12
# affecting the compiler, kernel execution, or both.
13
13
function split_kwargs (kwargs)
14
- compiler_kws = [:agent , :queue , :name ]
15
- call_kws = [:groupsize , :gridsize , :agent , :queue ]
14
+ # TODO : Alias groupsize and gridsize as threads and blocks, respectively
15
+ compiler_kws = [:device , :agent , :queue , :name ]
16
+ call_kws = [:groupsize , :gridsize , :device , :agent , :queue ]
16
17
compiler_kwargs = []
17
18
call_kwargs = []
18
19
for kwarg in kwargs
@@ -60,6 +61,23 @@ function assign_args!(code, args)
60
61
return vars, var_exprs
61
62
end
62
63
64
"""
    extract_device(; device=nothing, agent=nothing, kwargs...)

Select the device to use from a set of keyword arguments.

Precedence: an explicit `device` wins, then an explicit `agent`, and otherwise
`default_device()` is used. Any other keyword arguments are accepted and ignored so
that callers (e.g. `@roc`) can splat their full set of call keywords here.
"""
function extract_device(; device=nothing, agent=nothing, kwargs...)
    device !== nothing && return device
    agent !== nothing && return agent
    return default_device()
end
73
"""
    extract_queue(device; queue=nothing, kwargs...)

Select the queue to use from a set of keyword arguments.

Returns the explicit `queue` when one is given; otherwise returns
`default_queue(device)`. Any other keyword arguments are accepted and ignored so
that callers (e.g. `@roc`) can splat their full set of call keywords here.
"""
function extract_queue(device; queue=nothing, kwargs...)
    queue !== nothing && return queue
    return default_queue(device)
end
80
+
63
81
# fast lookup of global world age
# (thin wrapper over the runtime entry point `jl_get_tls_world_age`; returns a `UInt`)
world_age() = ccall(:jl_get_tls_world_age, UInt, ())
65
83
@@ -125,10 +143,10 @@ macro roc(ex...)
125
143
GC. @preserve $ (vars... ) begin
126
144
local kernel_args = map (rocconvert, ($ (var_exprs... ),))
127
145
local kernel_tt = Tuple{Core. Typeof .(kernel_args)... }
128
- local agent = get_default_agent ( )
129
- local kernel = rocfunction (agent , $ (esc (f)), kernel_tt;
146
+ local device = extract_device (; $ ( esc (call_kwargs) ... ) )
147
+ local kernel = rocfunction (device , $ (esc (f)), kernel_tt;
130
148
$ (map (esc, compiler_kwargs)... ))
131
- local queue = get_default_queue (agent )
149
+ local queue = extract_queue (device; $ ( esc (call_kwargs) ... ) )
132
150
local signal = HSASignal ()
133
151
kernel (queue, signal, kernel_args... ; $ (map (esc, call_kwargs)... ))
134
152
wait (signal)
@@ -188,7 +206,7 @@ The output of this function is automatically cached, i.e. you can simply call
188
206
generated automatically, when the function changes, or when different types or
189
207
keyword arguments are provided.
190
208
"""
191
- @generated function rocfunction (agent :: HSAAgent , f:: Core.Function , tt:: Type = Tuple{}; name= nothing , kwargs... )
209
+ @generated function rocfunction (device :: RuntimeDevice , f:: Core.Function , tt:: Type = Tuple{}; name= nothing , kwargs... )
192
210
tt = Base. to_tuple_type (tt. parameters[1 ])
193
211
sig = Base. signature_type (f, tt)
194
212
t = Tuple (tt. parameters)
@@ -217,8 +235,8 @@ keyword arguments are provided.
217
235
218
236
# compile the function
219
237
if ! haskey (compilecache, key)
220
- fun, mod = compile (:roc , agent , f, tt; name= name, kwargs... )
221
- kernel = Kernel {f,tt} (agent , mod, fun)
238
+ fun, mod = compile (:roc , device , f, tt; name= name, kwargs... )
239
+ kernel = Kernel {f,tt} (device , mod, fun)
222
240
compilecache[key] = kernel
223
241
end
224
242
@@ -227,9 +245,9 @@ keyword arguments are provided.
227
245
end
228
246
229
247
# Convenience method: same as `rocfunction(device, f, tt; kwargs...)`, with the
# device taken from `default_device()`. All keyword arguments are forwarded unchanged.
rocfunction(f::Core.Function, tt::Type=Tuple{}; kwargs...) =
    rocfunction(default_device(), f, tt; kwargs...)
231
249
232
- @generated function call (kernel:: Kernel{F,TT} , queue:: HSAQueue ,
250
+ @generated function call (kernel:: Kernel{F,TT} , queue:: RuntimeQueue ,
233
251
signal:: HSASignal , args... ; call_kwargs... ) where {F,TT}
234
252
235
253
sig = Base. signature_type (F, TT)
0 commit comments