Skip to content
This repository was archived by the owner on Nov 18, 2020. It is now read-only.

Add OpenCL runtime support #24

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ HSARuntime = "2c364e2c-59fb-59c3-96f3-194112e690e0"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[compat]
Expand All @@ -22,8 +23,10 @@ TimerOutputs = "0.5"
julia = "1"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["SpecialFunctions", "Test"]
test = ["LinearAlgebra", "OpenCL", "SpecialFunctions", "Test"]
10 changes: 10 additions & 0 deletions src/AMDGPUnative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ using Adapt
using TimerOutputs
using DataStructures
using Libdl
using Requires

@enum DeviceRuntime HSA OCL
const RUNTIME = Ref{DeviceRuntime}(HSA)
if get(ENV, "AMDGPUNATIVE_OPENCL", "") != ""
RUNTIME[] = OCL
end
include("runtime.jl")

const configured = HSARuntime.configured

Expand All @@ -21,6 +29,7 @@ include(joinpath("device", "pointer.jl"))
include(joinpath("device", "array.jl"))
include(joinpath("device", "gcn.jl"))
include(joinpath("device", "runtime.jl"))
include(joinpath("device", "llvm.jl"))

include("execution_utils.jl")
include("compiler.jl")
Expand All @@ -29,6 +38,7 @@ include("reflection.jl")

function __init__()
check_deps()
@require OpenCL="08131aa3-fb12-5dee-8b74-c09406e224a2" include("opencl.jl")
__init_compiler__()
end

Expand Down
6 changes: 3 additions & 3 deletions src/compiler/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ struct CompilerJob
# core invocation
f::Base.Callable
tt::DataType
agent::HSAAgent
device::RuntimeDevice
kernel::Bool

# optional properties
Expand All @@ -14,10 +14,10 @@ struct CompilerJob
maxregs::Union{Nothing,Integer}
name::Union{Nothing,String}

CompilerJob(f, tt, agent, kernel; name=nothing,
CompilerJob(f, tt, device, kernel; name=nothing,
minthreads=nothing, maxthreads=nothing,
blocks_per_sm=nothing, maxregs=nothing) =
new(f, tt, agent, kernel, minthreads, maxthreads, blocks_per_sm,
new(f, tt, device, kernel, minthreads, maxthreads, blocks_per_sm,
maxregs, name)
end

Expand Down
16 changes: 8 additions & 8 deletions src/compiler/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
const compile_hook = Ref{Union{Nothing,Function}}(nothing)

"""
compile(target::Symbol, agent::HSAAgent, f, tt, kernel=true;
compile(target::Symbol, device::RuntimeDevice, f, tt, kernel=true;
libraries=true, optimize=true, strip=false, strict=true, ...)
Compile a function `f` invoked with types `tt` for agent `agent` to one of the
following formats as specified by the `target` argument: `:julia` for Julia
IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
Compile a function `f` invoked with types `tt` for device `device` to one of
the following formats as specified by the `target` argument: `:julia` for
Julia IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
objects. If the `kernel` flag is set, specialized code generation and
optimization for kernel functions is enabled.
The following keyword arguments are supported:
Expand All @@ -18,11 +18,11 @@ The following keyword arguments are supported:
- `strict`: perform code validation either as early or as late as possible
Other keyword arguments can be found in the documentation of [`rocfunction`](@ref).
"""
compile(target::Symbol, agent::HSAAgent, @nospecialize(f::Core.Function),
compile(target::Symbol, device::RuntimeDevice, @nospecialize(f::Core.Function),
@nospecialize(tt), kernel::Bool=true; libraries::Bool=true,
optimize::Bool=true, strip::Bool=false, strict::Bool=true, kwargs...) =

compile(target, CompilerJob(f, tt, agent, kernel; kwargs...);
compile(target, CompilerJob(f, tt, device, kernel; kwargs...);
libraries=libraries, optimize=optimize, strip=strip,
strict=strict)

Expand Down Expand Up @@ -88,7 +88,7 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
# always preload the runtime, and do so early; it cannot be part of any timing block
# because it recurses into the compiler
if libraries
runtime = load_runtime(job.agent)
runtime = load_runtime(job.device)
runtime_fns = LLVM.name.(defs(runtime))
end

Expand All @@ -105,7 +105,7 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
end
=#
# FIXME: Load this only when needed
device_libs = load_device_libs(job.agent)
device_libs = load_device_libs(job.device)
for lib in device_libs
if need_library(ir, lib)
@timeit to[] "device library" link_device_lib!(job, ir, lib)
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/mcgen.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# machine code generation

function machine(agent::HSAAgent, triple::String)
function machine(device::RuntimeDevice, triple::String)
InitializeAMDGPUTarget()
InitializeAMDGPUTargetInfo()
t = Target(triple)

InitializeAMDGPUTargetMC()
cpu = get_first_isa(agent) # TODO: Make this configurable
cpu = default_isa(device) # TODO: Make this configurable
feat = ""
tm = TargetMachine(t, triple, cpu, feat)
asm_verbosity!(tm, true)
Expand Down Expand Up @@ -80,7 +80,7 @@ end

function mcgen(job::CompilerJob, mod::LLVM.Module, f::LLVM.Function;
output_format=LLVM.API.LLVMObjectFile)
tm = machine(job.agent, triple(mod))
tm = machine(job.device, triple(mod))

InitializeAMDGPUAsmPrinter()
return String(emit(tm, mod, output_format))
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/optim.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# LLVM IR optimization

function optimize!(job::CompilerJob, mod::LLVM.Module, entry::LLVM.Function)
tm = AMDGPUnative.machine(job.agent, triple(mod))
tm = AMDGPUnative.machine(job.device, triple(mod))

if job.kernel
entry = promote_kernel!(job, mod, entry)
Expand Down
18 changes: 9 additions & 9 deletions src/compiler/rtlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ const libcache = Dict{String, LLVM.Module}()

# ROCm device library

function load_device_libs(agent)
function load_device_libs(device)
device_libs_path === nothing && return

isa_short = replace(get_first_isa(agent), "gfx"=>"")
isa_short = replace(default_isa(device), "gfx"=>"")
device_libs = LLVM.Module[]
bitcode_files = (
"hc.amdgcn.bc",
Expand Down Expand Up @@ -132,20 +132,20 @@ end

## functionality to build the runtime library

function emit_function!(mod, agent, f, types, name)
function emit_function!(mod, device, f, types, name)
tt = Base.to_tuple_type(types)
new_mod, entry = codegen(:llvm, CompilerJob(f, tt, agent, #=kernel=# false);
new_mod, entry = codegen(:llvm, CompilerJob(f, tt, device, #=kernel=# false);
libraries=false, strict=false)
LLVM.name!(entry, name)
link!(mod, new_mod)
end

function build_runtime(agent)
function build_runtime(device)
mod = LLVM.Module("AMDGPUnative run-time library", JuliaContext())

for method in values(Runtime.methods)
try
emit_function!(mod, agent, method.def, method.types, method.llvm_name)
emit_function!(mod, device, method.def, method.types, method.llvm_name)
catch err
@warn method
end
Expand All @@ -154,8 +154,8 @@ function build_runtime(agent)
mod
end

function load_runtime(agent::HSAAgent)
isa = get_first_isa(agent)
function load_runtime(device::RuntimeDevice)
isa = default_isa(device)
name = "amdgpunative.$isa.bc"
path = joinpath(@__DIR__, "..", "..", "deps", "runtime", name)
mkpath(dirname(path))
Expand All @@ -167,7 +167,7 @@ function load_runtime(agent::HSAAgent)
end
else
@info "Building the AMDGPUnative run-time library for your $isa device, this might take a while..."
lib = build_runtime(agent)
lib = build_runtime(device)
open(path, "w") do io
write(io, lib)
end
Expand Down
49 changes: 33 additions & 16 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
export @roc, rocconvert, rocfunction

struct Kernel{F,TT}
agent::HSAAgent
device::RuntimeDevice
mod::ROCModule
fun::ROCFunction
end

# `split_kwargs()` segregates keyword arguments passed to `@roc` into those
# affecting the compiler, kernel execution, or both.
function split_kwargs(kwargs)
compiler_kws = [:agent, :queue, :name]
call_kws = [:groupsize, :gridsize, :agent, :queue]
# TODO: Alias groupsize and gridsize as threads and blocks, respectively
compiler_kws = [:device, :agent, :queue, :name]
call_kws = [:groupsize, :gridsize, :device, :agent, :queue]
compiler_kwargs = []
call_kwargs = []
for kwarg in kwargs
Expand Down Expand Up @@ -60,6 +61,23 @@ function assign_args!(code, args)
return vars, var_exprs
end

function extract_device(;device=nothing, agent=nothing, kwargs...)
if device !== nothing
return device
elseif agent !== nothing
return agent
else
return default_device()
end
end
function extract_queue(device; queue=nothing, kwargs...)
if queue !== nothing
return queue
else
return default_queue(device)
end
end

# fast lookup of global world age
world_age() = ccall(:jl_get_tls_world_age, UInt, ())

Expand Down Expand Up @@ -125,13 +143,12 @@ macro roc(ex...)
GC.@preserve $(vars...) begin
local kernel_args = map(rocconvert, ($(var_exprs...),))
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
local agent = get_default_agent()
local kernel = rocfunction(agent, $(esc(f)), kernel_tt;
local device = extract_device(; $(map(esc, call_kwargs)...))
local kernel = rocfunction(device, $(esc(f)), kernel_tt;
$(map(esc, compiler_kwargs)...))
local queue = get_default_queue(agent)
local signal = HSASignal()
kernel(queue, signal, kernel_args...; $(map(esc, call_kwargs)...))
wait(signal)
local queue = extract_queue(device; $(map(esc, call_kwargs)...))
local event = kernel(queue, kernel_args...; $(map(esc, call_kwargs)...))
wait(event)
end
end)
return code
Expand Down Expand Up @@ -188,7 +205,7 @@ The output of this function is automatically cached, i.e. you can simply call
generated automatically, when the function changes, or when different types or
keyword arguments are provided.
"""
@generated function rocfunction(agent::HSAAgent, f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
@generated function rocfunction(device::RuntimeDevice, f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
tt = Base.to_tuple_type(tt.parameters[1])
sig = Base.signature_type(f, tt)
t = Tuple(tt.parameters)
Expand Down Expand Up @@ -217,8 +234,8 @@ keyword arguments are provided.

# compile the function
if !haskey(compilecache, key)
fun, mod = compile(:roc, agent, f, tt; name=name, kwargs...)
kernel = Kernel{f,tt}(agent, mod, fun)
fun, mod = compile(:roc, device, f, tt; name=name, kwargs...)
kernel = Kernel{f,tt}(device, mod, fun)
compilecache[key] = kernel
end

Expand All @@ -227,10 +244,10 @@ keyword arguments are provided.
end

rocfunction(f::Core.Function, tt::Type=Tuple{}; kwargs...) =
rocfunction(get_default_agent(), f, tt; kwargs...)
rocfunction(default_device(), f, tt; kwargs...)

@generated function call(kernel::Kernel{F,TT}, queue::HSAQueue,
signal::HSASignal, args...; call_kwargs...) where {F,TT}
@generated function call(kernel::Kernel{F,TT}, queue::RuntimeQueue,
args...; call_kwargs...) where {F,TT}

sig = Base.signature_type(F, TT)
args = (:F, (:( args[$i] ) for i in 1:length(args))...)
Expand All @@ -254,7 +271,7 @@ rocfunction(f::Core.Function, tt::Type=Tuple{}; kwargs...) =

quote
Base.@_inline_meta
roccall(queue, signal, kernel.fun, $call_tt, $(call_args...); call_kwargs...)
roccall(queue, kernel.fun, $call_tt, $(call_args...); call_kwargs...)
end
end

Expand Down
Loading