Skip to content
This repository was archived by the owner on Nov 18, 2020. It is now read-only.

Commit 35e9965

Browse files
committed
Add OpenCL runtime support
Abstract runtime functionality by HSA (TODO: OCL) Use Requires to load OpenCL bindings Allow choosing runtime via environment variables
1 parent 8070d83 commit 35e9965

13 files changed

+152
-60
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ HSARuntime = "2c364e2c-59fb-59c3-96f3-194112e690e0"
1010
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1111
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1212
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
13+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1415

1516
[compat]

src/AMDGPUnative.jl

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ using Adapt
77
using TimerOutputs
88
using DataStructures
99
using Libdl
10+
using Requires
11+
12+
@enum DeviceRuntime HSA OCL
13+
const RUNTIME = Ref{DeviceRuntime}(HSA)
14+
if get(ENV, "AMDGPUNATIVE_OPENCL", "") != ""
15+
RUNTIME[] = OCL
16+
end
1017

1118
const configured = HSARuntime.configured
1219

@@ -29,6 +36,7 @@ include("reflection.jl")
2936

3037
function __init__()
3138
check_deps()
39+
@require OpenCL="08131aa3-fb12-5dee-8b74-c09406e224a2" include("opencl.jl")
3240
__init_compiler__()
3341
end
3442

src/compiler/common.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ struct CompilerJob
44
# core invocation
55
f::Base.Callable
66
tt::DataType
7-
agent::HSAAgent
7+
device::RuntimeDevice
88
kernel::Bool
99

1010
# optional properties
@@ -14,10 +14,10 @@ struct CompilerJob
1414
maxregs::Union{Nothing,Integer}
1515
name::Union{Nothing,String}
1616

17-
CompilerJob(f, tt, agent, kernel; name=nothing,
17+
CompilerJob(f, tt, device, kernel; name=nothing,
1818
minthreads=nothing, maxthreads=nothing,
1919
blocks_per_sm=nothing, maxregs=nothing) =
20-
new(f, tt, agent, kernel, minthreads, maxthreads, blocks_per_sm,
20+
new(f, tt, device, kernel, minthreads, maxthreads, blocks_per_sm,
2121
maxregs, name)
2222
end
2323

src/compiler/driver.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
const compile_hook = Ref{Union{Nothing,Function}}(nothing)
55

66
"""
7-
compile(target::Symbol, agent::HSAAgent, f, tt, kernel=true;
7+
compile(target::Symbol, device::RuntimeDevice, f, tt, kernel=true;
88
libraries=true, optimize=true, strip=false, strict=true, ...)
9-
Compile a function `f` invoked with types `tt` for agent `agent` to one of the
10-
following formats as specified by the `target` argument: `:julia` for Julia
11-
IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
9+
Compile a function `f` invoked with types `tt` for device `device` to one of
10+
the following formats as specified by the `target` argument: `:julia` for
11+
Julia IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
1212
objects. If the `kernel` flag is set, specialized code generation and
1313
optimization for kernel functions is enabled.
1414
The following keyword arguments are supported:
@@ -18,11 +18,11 @@ The following keyword arguments are supported:
1818
- `strict`: perform code validation either as early or as late as possible
1919
Other keyword arguments can be found in the documentation of [`rocfunction`](@ref).
2020
"""
21-
compile(target::Symbol, agent::HSAAgent, @nospecialize(f::Core.Function),
21+
compile(target::Symbol, device::RuntimeDevice, @nospecialize(f::Core.Function),
2222
@nospecialize(tt), kernel::Bool=true; libraries::Bool=true,
2323
optimize::Bool=true, strip::Bool=false, strict::Bool=true, kwargs...) =
2424

25-
compile(target, CompilerJob(f, tt, agent, kernel; kwargs...);
25+
compile(target, CompilerJob(f, tt, device, kernel; kwargs...);
2626
libraries=libraries, optimize=optimize, strip=strip,
2727
strict=strict)
2828

@@ -88,7 +88,7 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
8888
# always preload the runtime, and do so early; it cannot be part of any timing block
8989
# because it recurses into the compiler
9090
if libraries
91-
runtime = load_runtime(job.agent)
91+
runtime = load_runtime(job.device)
9292
runtime_fns = LLVM.name.(defs(runtime))
9393
end
9494

@@ -105,7 +105,7 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
105105
end
106106
=#
107107
# FIXME: Load this only when needed
108-
device_libs = load_device_libs(job.agent)
108+
device_libs = load_device_libs(job.device)
109109
for lib in device_libs
110110
if need_library(ir, lib)
111111
@timeit to[] "device library" link_device_lib!(job, ir, lib)

src/compiler/mcgen.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# machine code generation
22

3-
function machine(agent::HSAAgent, triple::String)
3+
function machine(device::RuntimeDevice, triple::String)
44
InitializeAMDGPUTarget()
55
InitializeAMDGPUTargetInfo()
66
t = Target(triple)
77

88
InitializeAMDGPUTargetMC()
9-
cpu = get_first_isa(agent) # TODO: Make this configurable
9+
cpu = default_isa(device) # TODO: Make this configurable
1010
feat = ""
1111
tm = TargetMachine(t, triple, cpu, feat)
1212
asm_verbosity!(tm, true)
@@ -80,7 +80,7 @@ end
8080

8181
function mcgen(job::CompilerJob, mod::LLVM.Module, f::LLVM.Function;
8282
output_format=LLVM.API.LLVMObjectFile)
83-
tm = machine(job.agent, triple(mod))
83+
tm = machine(job.device, triple(mod))
8484

8585
InitializeAMDGPUAsmPrinter()
8686
return String(emit(tm, mod, output_format))

src/compiler/optim.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# LLVM IR optimization
22

33
function optimize!(job::CompilerJob, mod::LLVM.Module, entry::LLVM.Function)
4-
tm = AMDGPUnative.machine(job.agent, triple(mod))
4+
tm = AMDGPUnative.machine(job.device, triple(mod))
55

66
if job.kernel
77
entry = promote_kernel!(job, mod, entry)

src/compiler/rtlib.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ const libcache = Dict{String, LLVM.Module}()
3030

3131
# ROCm device library
3232

33-
function load_device_libs(agent)
33+
function load_device_libs(device)
3434
device_libs_path === nothing && return
3535

36-
isa_short = replace(get_first_isa(agent), "gfx"=>"")
36+
isa_short = replace(default_isa(device), "gfx"=>"")
3737
device_libs = LLVM.Module[]
3838
bitcode_files = (
3939
"hc.amdgcn.bc",
@@ -132,20 +132,20 @@ end
132132

133133
## functionality to build the runtime library
134134

135-
function emit_function!(mod, agent, f, types, name)
135+
function emit_function!(mod, device, f, types, name)
136136
tt = Base.to_tuple_type(types)
137-
new_mod, entry = codegen(:llvm, CompilerJob(f, tt, agent, #=kernel=# false);
137+
new_mod, entry = codegen(:llvm, CompilerJob(f, tt, device, #=kernel=# false);
138138
libraries=false, strict=false)
139139
LLVM.name!(entry, name)
140140
link!(mod, new_mod)
141141
end
142142

143-
function build_runtime(agent)
143+
function build_runtime(device)
144144
mod = LLVM.Module("AMDGPUnative run-time library", JuliaContext())
145145

146146
for method in values(Runtime.methods)
147147
try
148-
emit_function!(mod, agent, method.def, method.types, method.llvm_name)
148+
emit_function!(mod, device, method.def, method.types, method.llvm_name)
149149
catch err
150150
@warn method
151151
end
@@ -154,8 +154,8 @@ function build_runtime(agent)
154154
mod
155155
end
156156

157-
function load_runtime(agent::HSAAgent)
158-
isa = get_first_isa(agent)
157+
function load_runtime(device::RuntimeDevice)
158+
isa = default_isa(device)
159159
name = "amdgpunative.$isa.bc"
160160
path = joinpath(@__DIR__, "..", "..", "deps", "runtime", name)
161161
mkpath(dirname(path))
@@ -167,7 +167,7 @@ function load_runtime(agent::HSAAgent)
167167
end
168168
else
169169
@info "Building the AMDGPUnative run-time library for your $isa device, this might take a while..."
170-
lib = build_runtime(agent)
170+
lib = build_runtime(device)
171171
open(path, "w") do io
172172
write(io, lib)
173173
end

src/execution.jl

+29-11
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
export @roc, rocconvert, rocfunction
44

55
struct Kernel{F,TT}
6-
agent::HSAAgent
6+
device::RuntimeDevice
77
mod::ROCModule
88
fun::ROCFunction
99
end
1010

1111
# `split_kwargs()` segregates keyword arguments passed to `@roc` into those
1212
# affecting the compiler, kernel execution, or both.
1313
function split_kwargs(kwargs)
14-
compiler_kws = [:agent, :queue, :name]
15-
call_kws = [:groupsize, :gridsize, :agent, :queue]
14+
# TODO: Alias groupsize and gridsize as threads and blocks, respectively
15+
compiler_kws = [:device, :agent, :queue, :name]
16+
call_kws = [:groupsize, :gridsize, :device, :agent, :queue]
1617
compiler_kwargs = []
1718
call_kwargs = []
1819
for kwarg in kwargs
@@ -60,6 +61,23 @@ function assign_args!(code, args)
6061
return vars, var_exprs
6162
end
6263

64+
function extract_device(;device=nothing, agent=nothing, kwargs...)
65+
if device !== nothing
66+
return device
67+
elseif agent !== nothing
68+
return agent
69+
else
70+
return default_device()
71+
end
72+
end
73+
function extract_queue(device; queue=nothing, kwargs...)
74+
if queue !== nothing
75+
return queue
76+
else
77+
return default_queue(device)
78+
end
79+
end
80+
6381
# fast lookup of global world age
6482
world_age() = ccall(:jl_get_tls_world_age, UInt, ())
6583

@@ -125,10 +143,10 @@ macro roc(ex...)
125143
GC.@preserve $(vars...) begin
126144
local kernel_args = map(rocconvert, ($(var_exprs...),))
127145
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
128-
local agent = get_default_agent()
129-
local kernel = rocfunction(agent, $(esc(f)), kernel_tt;
146+
local device = extract_device(; $(esc(call_kwargs)...))
147+
local kernel = rocfunction(device, $(esc(f)), kernel_tt;
130148
$(map(esc, compiler_kwargs)...))
131-
local queue = get_default_queue(agent)
149+
local queue = extract_queue(device; $(esc(call_kwargs)...))
132150
local signal = HSASignal()
133151
kernel(queue, signal, kernel_args...; $(map(esc, call_kwargs)...))
134152
wait(signal)
@@ -188,7 +206,7 @@ The output of this function is automatically cached, i.e. you can simply call
188206
generated automatically, when the function changes, or when different types or
189207
keyword arguments are provided.
190208
"""
191-
@generated function rocfunction(agent::HSAAgent, f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
209+
@generated function rocfunction(device::RuntimeDevice, f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
192210
tt = Base.to_tuple_type(tt.parameters[1])
193211
sig = Base.signature_type(f, tt)
194212
t = Tuple(tt.parameters)
@@ -217,8 +235,8 @@ keyword arguments are provided.
217235

218236
# compile the function
219237
if !haskey(compilecache, key)
220-
fun, mod = compile(:roc, agent, f, tt; name=name, kwargs...)
221-
kernel = Kernel{f,tt}(agent, mod, fun)
238+
fun, mod = compile(:roc, device, f, tt; name=name, kwargs...)
239+
kernel = Kernel{f,tt}(device, mod, fun)
222240
compilecache[key] = kernel
223241
end
224242

@@ -227,9 +245,9 @@ keyword arguments are provided.
227245
end
228246

229247
rocfunction(f::Core.Function, tt::Type=Tuple{}; kwargs...) =
230-
rocfunction(get_default_agent(), f, tt; kwargs...)
248+
rocfunction(default_device(), f, tt; kwargs...)
231249

232-
@generated function call(kernel::Kernel{F,TT}, queue::HSAQueue,
250+
@generated function call(kernel::Kernel{F,TT}, queue::RuntimeQueue,
233251
signal::HSASignal, args...; call_kwargs...) where {F,TT}
234252

235253
sig = Base.signature_type(F, TT)

src/execution_utils.jl

+7-17
Original file line numberDiff line numberDiff line change
@@ -101,23 +101,13 @@ end
101101
GC.@preserve $(arg_refs...) begin
102102
kernelParams = [$(arg_ptrs...)]
103103

104-
# link with ld.lld
105-
ld_path = HSARuntime.ld_lld_path
106-
@assert ld_path != "" "ld.lld was not found; cannot link kernel"
107-
# TODO: Do this more idiomatically
108-
io = open("/tmp/amdgpu-dump.o", "w")
109-
write(io, f.mod.data)
110-
close(io)
111-
run(`$ld_path -shared -o /tmp/amdgpu.exe /tmp/amdgpu-dump.o`)
112-
io = open("/tmp/amdgpu.exe", "r")
113-
data = read(io)
114-
close(io)
115-
116-
# generate executable and kernel instance
117-
exe = HSAExecutable(queue.agent, data, f.entry)
118-
kern = HSAKernelInstance(queue.agent, exe, f.entry, args)
119-
HSARuntime.launch!(queue, kern, signal;
120-
workgroup_size=groupsize, grid_size=gridsize)
104+
# create executable and kernel instance
105+
exe = create_executable(get_device(queue), data, f.entry)
106+
kern = create_kernel(get_device(queue), exe, f.entry, args)
107+
108+
# launch kernel
109+
launch_kernel(queue, kern, signal;
110+
workgroup_size=groupsize, grid_size=gridsize)
121111
end
122112
end).args)
123113

src/opencl.jl

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# OpenCL runtime interface to AMDGPUnative
2+
3+
include("opencl/args.jl")

src/opencl/args.jl

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# OpenCL argument utilities
2+
3+
# Argument accessors
4+
5+
# __agocl_global_offset_x - OpenCL Global Offset X
6+
# __agocl_global_offset_y - OpenCL Global Offset Y
7+
# __agocl_global_offset_z - OpenCL Global Offset Z
8+
# __agocl_printf_addr - OpenCL address of printf buffer
9+
# __agocl_queue_addr - OpenCL address of virtual queue used by enqueue_kernel
10+
# __agocl_aqlwrap_addr - OpenCL address of AqlWrap struct used by enqueue_kernel
11+
# __agocl_multigrid - Pointer argument used for Multi-grid synchronization
12+

src/reflection.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ using InteractiveUtils
77
# and/or to support generating otherwise invalid code (e.g. with missing symbols).
88

99
"""
10-
code_llvm([io], f, types; optimize=true, agent::HSAAgent=get_default_agent(), kernel=false,
10+
code_llvm([io], f, types; optimize=true, device::RuntimeDevice=default_device(), kernel=false,
1111
optimize=true, raw=false, dump_module=false, strict=false)
1212
1313
Prints the device LLVM IR generated for the method matching the given generic function and
1414
type signature to `io` which defaults to `stdout`.
1515
1616
The following keyword arguments are supported:
1717
18-
- `agent`: which device to generate code for
18+
- `device`: which device to generate code for
1919
- `kernel`: treat the function as an entry-point kernel
2020
- `optimize`: determines if the code is optimized, which includes kernel-specific
2121
optimizations if `kernel` is true
@@ -26,11 +26,11 @@ The following keyword arguments are supported:
2626
See also: [`@device_code_llvm`](@ref), [`InteractiveUtils.code_llvm`](@ref)
2727
"""
2828
function code_llvm(io::IO, @nospecialize(func), @nospecialize(types);
29-
optimize::Bool=true, agent::HSAAgent=get_default_agent(),
29+
optimize::Bool=true, device::RuntimeDevice=default_device(),
3030
dump_module::Bool=false, raw::Bool=false,
3131
kernel::Bool=false, strict::Bool=false, kwargs...)
3232
tt = Base.to_tuple_type(types)
33-
job = CompilerJob(func, tt, agent, kernel; kwargs...)
33+
job = CompilerJob(func, tt, device, kernel; kwargs...)
3434
code_llvm(io, job; optimize=optimize,
3535
raw=raw, dump_module=dump_module, strict=strict)
3636
end
@@ -49,25 +49,25 @@ code_llvm(@nospecialize(func), @nospecialize(types); kwargs...) =
4949
code_llvm(stdout, func, types; kwargs...)
5050

5151
"""
52-
code_gcn([io], f, types; agent::HSAAgent=get_default_agent(), kernel=false, raw=false, strict=false)
52+
code_gcn([io], f, types; device::RuntimeDevice=default_device(), kernel=false, raw=false, strict=false)
5353
5454
Prints the GCN assembly generated for the method matching the given generic function and
5555
type signature to `io` which defaults to `stdout`.
5656
5757
The following keyword arguments are supported:
5858
59-
- `agent`: which device to generate code for
59+
- `device`: which device to generate code for
6060
- `kernel`: treat the function as an entry-point kernel
6161
- `raw`: return the raw code including all metadata
6262
- `strict`: verify generate code as early as possible
6363
6464
See also: [`@device_code_gcn`](@ref)
6565
"""
6666
function code_gcn(io::IO, @nospecialize(func), @nospecialize(types);
67-
agent::HSAAgent=get_default_agent(), kernel::Bool=false,
67+
device::RuntimeDevice=default_device(), kernel::Bool=false,
6868
raw::Bool=false, strict::Bool=false, kwargs...)
6969
tt = Base.to_tuple_type(types)
70-
job = CompilerJob(func, tt, agent, kernel; kwargs...)
70+
job = CompilerJob(func, tt, device, kernel; kwargs...)
7171
code_gcn(io, job; raw=raw, strict=strict)
7272
end
7373

0 commit comments

Comments
 (0)