diff --git a/Manifest.toml b/Manifest.toml index c2f9d9c..0e555eb 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -49,11 +49,11 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[GPUCompiler]] deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "3f405c2aab2ef755d022bd57466a35c6d08a1531" -repo-rev = "158cd601fc42faed088785a7bde16436cbaa6017" +git-tree-sha1 = "65f7395a1245635f0c2279649fdbef09a1b0aa7b" +repo-rev = "master" repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.3.0" +version = "0.4.0" [[HSARuntime]] deps = ["CEnum", "Libdl", "Setfield"] diff --git a/Project.toml b/Project.toml index a1270b9..7a08d78 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] Adapt = "0.4, 1.0" BinaryProvider = "0.5" -GPUCompiler = "0.3" +GPUCompiler = "0.4" HSARuntime = "0.3" LLVM = "1.3" Requires = "1" diff --git a/src/AMDGPUnative.jl b/src/AMDGPUnative.jl index beda9e4..b322162 100644 --- a/src/AMDGPUnative.jl +++ b/src/AMDGPUnative.jl @@ -37,6 +37,7 @@ include(joinpath("device", "globals.jl")) include("compiler.jl") include("execution_utils.jl") include("execution.jl") +include("exceptions.jl") include("reflection.jl") function __init__() diff --git a/src/compiler.jl b/src/compiler.jl index d76900f..5b2e21f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -15,7 +15,14 @@ function GPUCompiler.process_module!(job::ROCCompilerJob, mod::LLVM.Module) invoke(GPUCompiler.process_module!, Tuple{CompilerJob{GCNCompilerTarget}, typeof(mod)}, job, mod) - #emit_exception_flag!(mod) + # Run this early (before optimization) to ensure we link OCKL + emit_exception_user!(mod) +end +function GPUCompiler.finish_module!(job::ROCCompilerJob, mod::LLVM.Module) + invoke(GPUCompiler.finish_module!, + Tuple{CompilerJob{GCNCompilerTarget}, typeof(mod)}, + job, mod) + delete_exception_user!(mod) end function GPUCompiler.link_libraries!(job::ROCCompilerJob, mod::LLVM.Module, diff --git a/src/device/gcn.jl b/src/device/gcn.jl index 027cee2..d4d303d 100644 --- a/src/device/gcn.jl +++ b/src/device/gcn.jl @@ -1,10 +1,13 @@ -if Base.libllvm_version >= v"7.0" - include(joinpath("gcn", "math.jl")) -end +# HSA dispatch packet offsets +_packet_names = fieldnames(HSA.KernelDispatchPacket) +_packet_offsets = fieldoffset.(HSA.KernelDispatchPacket, 1:length(_packet_names)) + +include(joinpath("gcn", "math.jl")) include(joinpath("gcn", "indexing.jl")) include(joinpath("gcn", "assertion.jl")) include(joinpath("gcn", "synchronization.jl")) include(joinpath("gcn", "memory_static.jl")) -include(joinpath("gcn", "memory_dynamic.jl")) include(joinpath("gcn", "hostcall.jl")) include(joinpath("gcn", "output.jl")) +include(joinpath("gcn", "memory_dynamic.jl")) +include(joinpath("gcn", "execution_control.jl")) diff --git a/src/device/gcn/execution_control.jl b/src/device/gcn/execution_control.jl new file mode 100644 index 0000000..190cbcd --- /dev/null +++ b/src/device/gcn/execution_control.jl @@ -0,0 +1,41 @@ +## completion signal + +const completion_signal_base = _packet_offsets[findfirst(x->x==:completion_signal,_packet_names)] + +@generated function _completion_signal() + T_int8 = LLVM.Int8Type(JuliaContext()) + T_int64 = LLVM.Int64Type(JuliaContext()) + _as = convert(Int, AS.Constant) + T_ptr_i8 = LLVM.PointerType(T_int8, _as) + T_ptr_i64 = LLVM.PointerType(T_int64, _as) + + # create function + llvm_f, _ = create_function(T_int64) + mod = LLVM.parent(llvm_f) + + # generate IR + Builder(JuliaContext()) do builder + entry = BasicBlock(llvm_f, "entry", JuliaContext()) + position!(builder, entry) + + # get the kernel dispatch pointer + intr_typ = LLVM.FunctionType(T_ptr_i8) + intr = LLVM.Function(mod, "llvm.amdgcn.dispatch.ptr", intr_typ) + ptr = call!(builder, intr) + + # load the index + signal_ptr_i8 = inbounds_gep!(builder, ptr, [ConstantInt(completion_signal_base, JuliaContext())]) + signal_ptr = bitcast!(builder, signal_ptr_i8, T_ptr_i64) + signal = load!(builder, signal_ptr) + ret!(builder, signal) + end + + call_function(llvm_f, UInt64) +end + +signal_completion(value::Int64) = device_signal_store!(_completion_signal(), value) + +## misc. intrinsics +@inline sendmsg(x1, x2=Int32(0)) = ccall("llvm.amdgcn.s.sendmsg", llvmcall, Cvoid, (Int32, Int32), x1, x2) +@inline sendmsghalt(x1, x2=Int32(0)) = ccall("llvm.amdgcn.s.sendmsghalt", llvmcall, Cvoid, (Int32, Int32), x1, x2) +@inline endpgm() = @asmcall("s_endpgm", "", true) diff --git a/src/device/gcn/indexing.jl b/src/device/gcn/indexing.jl index 9466b46..9935802 100644 --- a/src/device/gcn/indexing.jl +++ b/src/device/gcn/indexing.jl @@ -33,7 +33,7 @@ end @generated function _dim(::Val{base}, ::Val{off}, ::Val{range}, ::Type{T}) where {base, off, range, T} T_int8 = LLVM.Int8Type(JuliaContext()) T_int32 = LLVM.Int32Type(JuliaContext()) - _as = Base.libllvm_version < v"7.0" ? 2 : 4 + _as = convert(Int, AS.Constant) T_ptr_i8 = LLVM.PointerType(T_int8, _as) T_ptr_i32 = LLVM.PointerType(T_int32, _as) T_ptr_T = LLVM.PointerType(convert(LLVMType, T), _as) @@ -91,8 +91,6 @@ for dim in (:x, :y, :z) cufn = Symbol("blockIdx_$dim") @eval @inline $cufn() = $fn() end -_packet_names = fieldnames(HSA.KernelDispatchPacket) -_packet_offsets = fieldoffset.(HSA.KernelDispatchPacket, 1:length(_packet_names)) for (dim,off) in ((:x,1), (:y,2), (:z,3)) # Workitem dimension fn = Symbol("workgroupDim_$dim") diff --git a/src/device/gcn/memory_dynamic.jl b/src/device/gcn/memory_dynamic.jl index 186c99f..06489b9 100644 --- a/src/device/gcn/memory_dynamic.jl +++ b/src/device/gcn/memory_dynamic.jl @@ -1,4 +1,75 @@ -export malloc +export malloc, free -# Stub implementation -malloc(::Csize_t) = C_NULL +function malloc(sz::Csize_t) + malloc_gbl = get_global_pointer(Val(:__global_malloc_hostcall), + HostCall{UInt64,DevicePtr{UInt8,AS.Global},Tuple{Csize_t}}) + malloc_hc = Base.unsafe_load(malloc_gbl) + ptr = hostcall!(malloc_hc, sz) + if UInt64(ptr) != 0 + kernel_metadata_insert!(ptr, sz) + end + return ptr +end + +function free(ptr::DevicePtr{T,AS.Global}) where T + free_gbl = get_global_pointer(Val(:__global_free_hostcall), + HostCall{UInt64,Nothing,Tuple{DevicePtr{UInt8,AS.Global}}}) + free_hc = Base.unsafe_load(free_gbl) + hostcall!(free_hc, Base.unsafe_convert(DevicePtr{UInt8,AS.Global}, ptr)) + kernel_metadata_delete!(ptr) +end + +# metadata store +struct MetadataInsertException <: Exception + kern::UInt64 +end +function Base.showerror(io::IO, mae::MetadataInsertException) + print(io, "MetadataInsertException: Failed to insert metadata for kernel ") + print(io, mae.kern) +end +function kernel_metadata_insert!(ptr, sz) + metadata_gbl = get_global_pointer(Val(:__global_metadata_store), KernelMetadata) + offset = 1 + while true + # FIXME: atomic_load + metadata = Base.unsafe_load(metadata_gbl, offset) + if metadata.kern == 0 + # empty metadata slot, use it + # FIXME: atomic_store! + Base.unsafe_store!(metadata_gbl, KernelMetadata(_completion_signal(), ptr, sz), offset) + return true + elseif metadata.kern == 1 + # tail slot, error + # FIXME: throw(MetadataInsertException(_completion_signal())) + return false + else + # slot in use, skip it + offset += 1 + end + end +end +function kernel_metadata_delete!(ptr) + metadata_gbl = get_global_pointer(Val(:__global_metadata_store), KernelMetadata) + offset = 1 + our_signal = _completion_signal() + while true + # FIXME: atomic_load + metadata = Base.unsafe_load(metadata_gbl, offset) + if metadata.kern == our_signal + # our slot, clear it + # FIXME: atomic_store! + metadata_gbl_ptr = convert(DevicePtr{UInt8,AS.Global}, + Base.unsafe_convert(Ptr{KernelMetadata}, metadata_gbl) + + (sizeof(KernelMetadata)*(offset-1))) + memset!(metadata_gbl_ptr, 0x0, Csize_t(sizeof(KernelMetadata))) + return true + elseif metadata.kern == 1 + # tail slot, error + # FIXME: throw(MetadataDeleteException(_completion_signal())) + return false + else + # not our slot, skip it + offset += 1 + end + end +end diff --git a/src/device/gcn/memory_static.jl b/src/device/gcn/memory_static.jl index 76d2ebd..812c60c 100644 --- a/src/device/gcn/memory_static.jl +++ b/src/device/gcn/memory_static.jl @@ -41,3 +41,74 @@ export alloc_special call_function(llvm_f, DevicePtr{T,as}) end + +@inline @generated function alloc_string(::Val{str}) where str + T_pint8_generic = LLVM.PointerType(LLVM.Int8Type(JuliaContext()), convert(Int, AS.Generic)) + llvm_f, _ = create_function(LLVM.Int64Type(JuliaContext())) + Builder(JuliaContext()) do builder + entry = BasicBlock(llvm_f, "entry", JuliaContext()) + position!(builder, entry) + str_ptr = globalstring_ptr!(builder, String(str)) + str_ptr_i64 = ptrtoint!(builder, str_ptr, LLVM.Int64Type(JuliaContext())) + ret!(builder, str_ptr_i64) + end + call_function(llvm_f, DevicePtr{UInt8,AS.Generic}) +end + +# TODO: Support various types of len +@inline @generated function memcpy!(dest_ptr::DevicePtr{UInt8,DestAS}, src_ptr::DevicePtr{UInt8,SrcAS}, len::LT) where {DestAS,SrcAS,LT<:Union{Int64,UInt64}} + T_nothing = LLVM.VoidType(JuliaContext()) + dest_as = convert(Int, DestAS) + src_as = convert(Int, SrcAS) + T_int8 = LLVM.Int8Type(JuliaContext()) + T_int64 = LLVM.Int64Type(JuliaContext()) + T_pint8_dest = LLVM.PointerType(T_int8, dest_as) + T_pint64_dest = LLVM.PointerType(T_int64, dest_as) + T_pint8_src = LLVM.PointerType(T_int8, src_as) + T_pint64_src = LLVM.PointerType(T_int64, src_as) + T_int1 = LLVM.Int1Type(JuliaContext()) + + llvm_f, _ = create_function(T_nothing, [T_int64, T_int64, T_int64]) + mod = LLVM.parent(llvm_f) + T_intr = LLVM.FunctionType(T_nothing, [T_pint8_dest, T_pint8_src, T_int64, T_int1]) + intr = LLVM.Function(mod, "llvm.memcpy.p$(dest_as)i8.p$(src_as)i8.i64", T_intr) + Builder(JuliaContext()) do builder + entry = BasicBlock(llvm_f, "entry", JuliaContext()) + position!(builder, entry) + + dest_ptr_i64 = inttoptr!(builder, parameters(llvm_f)[1], T_pint64_dest) + dest_ptr_i8 = bitcast!(builder, dest_ptr_i64, T_pint8_dest) + + src_ptr_i64 = inttoptr!(builder, parameters(llvm_f)[2], T_pint64_src) + src_ptr_i8 = bitcast!(builder, src_ptr_i64, T_pint8_src) + + call!(builder, intr, [dest_ptr_i8, src_ptr_i8, parameters(llvm_f)[3], ConstantInt(T_int1, 0)]) + ret!(builder) + end + call_function(llvm_f, Nothing, Tuple{DevicePtr{UInt8,DestAS},DevicePtr{UInt8,SrcAS},LT}, :((dest_ptr, src_ptr, len))) +end +@inline @generated function memset!(dest_ptr::DevicePtr{UInt8,DestAS}, value::UInt8, len::LT) where {DestAS,LT<:Union{Int64,UInt64}} + T_nothing = LLVM.VoidType(JuliaContext()) + dest_as = convert(Int, DestAS) + T_int8 = LLVM.Int8Type(JuliaContext()) + T_int64 = LLVM.Int64Type(JuliaContext()) + T_pint8_dest = LLVM.PointerType(T_int8, dest_as) + T_pint64_dest = LLVM.PointerType(T_int64, dest_as) + T_int1 = LLVM.Int1Type(JuliaContext()) + + llvm_f, _ = create_function(T_nothing, [T_int64, T_int8, T_int64]) + mod = LLVM.parent(llvm_f) + T_intr = LLVM.FunctionType(T_nothing, [T_pint8_dest, T_int8, T_int64, T_int1]) + intr = LLVM.Function(mod, "llvm.memset.p$(dest_as)i8.i64", T_intr) + Builder(JuliaContext()) do builder + entry = BasicBlock(llvm_f, "entry", JuliaContext()) + position!(builder, entry) + + dest_ptr_i64 = inttoptr!(builder, parameters(llvm_f)[1], T_pint64_dest) + dest_ptr_i8 = bitcast!(builder, dest_ptr_i64, T_pint8_dest) + + call!(builder, intr, [dest_ptr_i8, parameters(llvm_f)[2], parameters(llvm_f)[3], ConstantInt(T_int1, 0)]) + ret!(builder) + end + call_function(llvm_f, Nothing, Tuple{DevicePtr{UInt8,DestAS},UInt8,LT}, :((dest_ptr, value, len))) +end diff --git a/src/device/gcn/output.jl b/src/device/gcn/output.jl index 78ff6f7..f4fd068 100644 --- a/src/device/gcn/output.jl +++ b/src/device/gcn/output.jl @@ -24,6 +24,8 @@ function OutputContext(io::IO=stdout; agent=get_default_agent(), buf_len=2^16, k OutputContext(hc) end +const GLOBAL_OUTPUT_CONTEXT_TYPE = OutputContext{HostCall{UInt64,Int64,Tuple{DeviceStaticString{2^16}}}} + ### macros macro rocprint(oc, str) @@ -33,9 +35,30 @@ macro rocprintln(oc, str) rocprint(oc, str, true) end +macro rocprint(str) + @gensym oc_ptr oc + ex = quote + $(esc(oc_ptr)) = AMDGPUnative.get_global_pointer(Val(:__global_output_context), + $GLOBAL_OUTPUT_CONTEXT_TYPE) + $(esc(oc)) = Base.unsafe_load($(esc(oc_ptr))) + end + push!(ex.args, rocprint(oc, str)) + ex +end +macro rocprintln(str) + @gensym oc_ptr oc + ex = quote + $(esc(oc_ptr)) = AMDGPUnative.get_global_pointer(Val(:__global_output_context), + $GLOBAL_OUTPUT_CONTEXT_TYPE) + $(esc(oc)) = Base.unsafe_load($(esc(oc_ptr))) + end + push!(ex.args, rocprint(oc, str, true)) + ex +end + ### parse-time helpers -function rocprint(oc, str, nl=false) +function rocprint(oc, str, nl::Bool=false) ex = Expr(:block) if !(str isa Expr) str = Expr(:string, str) @@ -50,18 +73,14 @@ function rocprint(oc, str, nl=false) dstr = DeviceStaticString{N}() push!(ex.args, :(hostcall!($(esc(oc)).hostcall, $dstr))) end + push!(ex.args, :(nothing)) return ex end function rocprint!(ex, N, oc, str::String) - # TODO: push!(ex.args, :($rocprint!($(esc(oc)), $(Val(Symbol(str)))))) - off = N - ptr = :(Base.unsafe_convert(DevicePtr{UInt8,AS.Global}, $(esc(oc)).hostcall.buf_ptr)) - for byte in codeunits(str) - push!(ex.args, :(Base.unsafe_store!($ptr, $byte, $off))) - off += 1 - end - - return off + @gensym str_ptr + push!(ex.args, :($str_ptr = AMDGPUnative.alloc_string($(Val(Symbol(str)))))) + push!(ex.args, :(AMDGPUnative.memcpy!($(esc(oc)).hostcall.buf_ptr+$(N-1), $str_ptr, $(length(str))))) + return N+length(str) end function rocprint!(ex, N, oc, char::Char) @assert char == '\0' "Non-null chars not yet implemented" @@ -84,29 +103,3 @@ end =# ### runtime helpers - -#= TODO: LLVM hates me, but this should eventually work -# FIXME: Pass N and offset oc.buf_ptr appropriately -@inline @generated function rocprint!(oc::OutputContext, ::Val{str}) where str - T_int1 = LLVM.Int1Type(JuliaContext()) - T_int32 = LLVM.Int32Type(JuliaContext()) - T_pint8 = LLVM.PointerType(LLVM.Int8Type(JuliaContext())) - T_pint8_global = LLVM.PointerType(LLVM.Int8Type(JuliaContext()), convert(Int, AS.Global)) - T_nothing = LLVM.VoidType(JuliaContext()) - llvm_f, _ = create_function(T_nothing, [T_pint8_global]) - mod = LLVM.parent(llvm_f) - T_intr = LLVM.FunctionType(T_nothing, [T_pint8_global, T_pint8, T_int32, T_int32, T_int1]) - intr = LLVM.Function(mod, "llvm.memcpy.p1i8.p0i8.i32", T_intr) - Builder(JuliaContext()) do builder - entry = BasicBlock(llvm_f, "entry", JuliaContext()) - position!(builder, entry) - str_ptr = globalstring_ptr!(builder, String(str)) - buf_ptr = parameters(llvm_f)[1] - # NOTE: There's a hidden alignment parameter (argument 4) that's not documented in the LangRef - call!(builder, intr, [buf_ptr, str_ptr, ConstantInt(Int32(length(string(str))), JuliaContext()), ConstantInt(Int32(2), JuliaContext()), ConstantInt(T_int1, 0)]) - ret!(builder) - end - Core.println(unsafe_string(LLVM.API.LLVMPrintValueToString(LLVM.ref(llvm_f)))) - call_function(llvm_f, Nothing, Tuple{DevicePtr{UInt8,AS.Global}}, :((oc.hostcall.buf_ptr,))) -end -=# diff --git a/src/device/pointer.jl b/src/device/pointer.jl index 05c2b92..1371f42 100644 --- a/src/device/pointer.jl +++ b/src/device/pointer.jl @@ -91,23 +91,13 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x # memory operations -@static if Base.libllvm_version < v"7.0" - # Old values (LLVM 6) - Base.convert(::Type{Int}, ::Type{AS.Private}) = 0 - Base.convert(::Type{Int}, ::Type{AS.Global}) = 1 - Base.convert(::Type{Int}, ::Type{AS.Constant}) = 2 - Base.convert(::Type{Int}, ::Type{AS.Local}) = 3 - Base.convert(::Type{Int}, ::Type{AS.Generic}) = 4 - Base.convert(::Type{Int}, ::Type{AS.Region}) = 5 -else - # New values (LLVM 7+) - Base.convert(::Type{Int}, ::Type{AS.Generic}) = 0 - Base.convert(::Type{Int}, ::Type{AS.Global}) = 1 - Base.convert(::Type{Int}, ::Type{AS.Region}) = 2 - Base.convert(::Type{Int}, ::Type{AS.Local}) = 3 - Base.convert(::Type{Int}, ::Type{AS.Constant}) = 4 - Base.convert(::Type{Int}, ::Type{AS.Private}) = 5 -end +# New values (LLVM 7+) +Base.convert(::Type{Int}, ::Type{AS.Generic}) = 0 +Base.convert(::Type{Int}, ::Type{AS.Global}) = 1 +Base.convert(::Type{Int}, ::Type{AS.Region}) = 2 +Base.convert(::Type{Int}, ::Type{AS.Local}) = 3 +Base.convert(::Type{Int}, ::Type{AS.Constant}) = 4 +Base.convert(::Type{Int}, ::Type{AS.Private}) = 5 function tbaa_make_child(name::String, constant::Bool=false; ctx::LLVM.Context=JuliaContext()) tbaa_root = MDNode([MDString("roctbaa", ctx)], ctx) diff --git a/src/device/runtime.jl b/src/device/runtime.jl index 62635e4..1d77b2f 100644 --- a/src/device/runtime.jl +++ b/src/device/runtime.jl @@ -15,48 +15,57 @@ function load_runtime(dev_isa::String) GPUCompiler.load_runtime(job) end -#@inline exception_flag() = ccall("extern julia_exception_flag", llvmcall, Ptr{Cvoid}, ()) - function signal_exception() -#= - ptr = exception_flag() - if ptr !== C_NULL - unsafe_store!(convert(Ptr{Int}, ptr), 1) - threadfence_system() - else - @rocprintf(""" - WARNING: could not signal exception status to the host, execution will continue. - Please file a bug. - """) - end -=# + flag_ptr = get_global_pointer(Val(:__global_exception_flag), Int64) + Base.unsafe_store!(flag_ptr, 1) + # TODO: threadfence_system() + signal_completion(0) + sendmsghalt(5) # stop wavefront generation + sendmsghalt(6) # halt all running wavefronts + # TODO: endpgm() return end function report_exception(ex) -#= + #= FIXME @rocprintf(""" ERROR: a %s was thrown during kernel execution. Run Julia on debug level 2 for device stack traces. """, ex) -=# + =# + @rocprint(""" + ERROR: an exception was thrown during kernel execution. + Run Julia on debug level 2 for device stack traces. + """) + # TODO: Pass exception info and kernel ID to a global + #= + ring_ptr = get_global_pointer(Val(:__global_exception_ring), ExceptionEntry) + ee = ExceptionEntry(_completion_signal(), C_NULL) # FIXME: DevicePtr(pointer(ex))) + Base.unsafe_store!(ring_ptr, ee) + =# return end -report_oom(sz) = nothing #@rocprintf("ERROR: Out of dynamic GPU memory (trying to allocate %i bytes)\n", sz) +# FIXME: report_oom(sz) = @rocprintf("ERROR: Out of dynamic GPU memory (trying to allocate %i bytes)\n", sz) +report_oom(sz) = @rocprintln("ERROR: Out of dynamic GPU memory") function report_exception_name(ex) -#= + #= FIXME @rocprintf(""" ERROR: a %s was thrown during kernel execution. Stacktrace: """, ex) -=# + =# + @rocprint(""" + ERROR: an exception was thrown during kernel execution. + Stacktrace: + """) return end function report_exception_frame(idx, func, file, line) - #@rocprintf(" [%i] %s at %s:%i\n", idx, func, file, line) + # FIXME: @rocprintf(" [%i] %s at %s:%i\n", idx, func, file, line) + #@rocprintln(" [%i] %s at %s:%i") return end @@ -99,6 +108,7 @@ end function link_device_libs!(mod::LLVM.Module, dev_isa::String, undefined_fns) libs::Vector{LLVM.Module} = load_device_libs(dev_isa) + ufns = undefined_fns # TODO: only link if used # TODO: make these globally/locally configurable link_oclc_defaults!(mod, dev_isa) diff --git a/src/device/tools.jl b/src/device/tools.jl index 128538a..0dffa25 100644 --- a/src/device/tools.jl +++ b/src/device/tools.jl @@ -333,6 +333,6 @@ llvmsize(::LLVM.LLVMHalf) = sizeof(Float16) llvmsize(::LLVM.LLVMFloat) = sizeof(Float32) llvmsize(::LLVM.LLVMDouble) = sizeof(Float64) llvmsize(::LLVM.IntegerType) = div(Int(intwidth(GenericValue(LLVM.Int128Type(), -1))), 8) -llvmsize(ty::LLVM.ArrayType) = length*llvmsize(eltype(ty)) +llvmsize(ty::LLVM.ArrayType) = length(ty)*llvmsize(eltype(ty)) # TODO: VectorType, StructType, PointerType llvmsize(ty) = error("Unknown size for type: $ty, typeof: $(typeof(ty))") diff --git a/src/exceptions.jl b/src/exceptions.jl new file mode 100644 index 0000000..f65e1d0 --- /dev/null +++ b/src/exceptions.jl @@ -0,0 +1,57 @@ +# support for device-side exceptions (from CUDAnative/src/exceptions.jl) + +## exception type + +struct KernelException <: Exception + dev::RuntimeDevice +end + +function Base.showerror(io::IO, err::KernelException) + print(io, "KernelException: exception thrown during kernel execution on device $(err.dev.device)") +end + +## exception ring buffer + +struct ExceptionEntry + kern_id::UInt64 + ptr::DevicePtr{Any,AS.Global} +end +ExceptionEntry() = ExceptionEntry(0, DevicePtr{Any,AS.Global}(0)) + +## exception codegen + +# emit a global variable for storing the current exception status +function emit_exception_user!(mod::LLVM.Module) + # add a fake user for __ockl_hsa_signal_store and __ockl_hsa_signal_load + if !haskey(LLVM.functions(mod), "__fake_global_exception_flag_user") + ctx = JuliaContext() + ft = LLVM.FunctionType(LLVM.VoidType(ctx)) + fn = LLVM.Function(mod, "__fake_global_exception_flag_user", ft) + Builder(ctx) do builder + entry = BasicBlock(fn, "entry", ctx) + position!(builder, entry) + T_nothing = LLVM.VoidType(ctx) + T_i32 = LLVM.Int32Type(ctx) + T_i64 = LLVM.Int64Type(ctx) + T_signal_store = LLVM.FunctionType(T_nothing, [T_i64, T_i64, T_i32]) + signal_store = LLVM.Function(mod, "__ockl_hsa_signal_store", T_signal_store) + call!(builder, signal_store, [ConstantInt(0,ctx), + ConstantInt(0,ctx), + # __ATOMIC_RELEASE == 3 + ConstantInt(Int32(3), JuliaContext())]) + T_signal_load = LLVM.FunctionType(T_i64, [T_i64, T_i32]) + signal_load = LLVM.Function(mod, "__ockl_hsa_signal_load", T_signal_load) + loaded_value = call!(builder, signal_load, [ConstantInt(0,ctx), + # __ATOMIC_ACQUIRE == 2 + ConstantInt(Int32(2), JuliaContext())]) + ret!(builder) + end + end + @assert haskey(LLVM.functions(mod), "__fake_global_exception_flag_user") +end +function delete_exception_user!(mod::LLVM.Module) + if haskey(LLVM.functions(mod), "__fake_global_exception_flag_user") + delete!(mod, LLVM.functions(mod)["__fake_global_exception_flag_user"]) + end + @assert !haskey(LLVM.functions(mod), "__fake_global_exception_flag_user") +end diff --git a/src/execution.jl b/src/execution.jl index a9429cc..5f963b0 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -183,7 +183,7 @@ macro roc(ex...) #local $device = $extract_device(; $(call_kwargs...)) local $kernel = $rocfunction($f, $kernel_tt; $(compiler_kwargs...)) #local $queue = $extract_queue($device; $(call_kwargs...)) - local $signal = $create_event() + local $signal = $create_event($kernel.mod.exe) $kernel($kernel_args...; signal=$signal, $(call_kwargs...)) $signal end @@ -285,9 +285,8 @@ end @doc (@doc AbstractKernel) HostKernel -@inline function roccall(kernel::HostKernel, tt, args...; config=nothing, kwargs...) +@inline function roccall(kernel::HostKernel, tt, args...; config=nothing, signal, kwargs...) queue = get(kwargs, :queue, default_queue(default_device())) - signal = get(kwargs, :signal, create_event()) if config !== nothing roccall(kernel.fun, tt, args...; kwargs..., config(kernel)..., queue=queue, signal=signal) else @@ -322,7 +321,7 @@ function _rocfunction(source::FunctionSpec; device=default_device(), queue=defau target = GCNCompilerTarget(; dev_isa=default_isa(device), kwargs...) params = ROCCompilerParams() job = CompilerJob(target, source, params) - obj, kernel_fn, undefined_fns, undefined_gbls = GPUCompiler.compile(:obj, job; strict=true) + obj, kernel_fn, undefined_fns, undefined_gbls = GPUCompiler.compile(:obj, job) # settings to JIT based on Julia's debug setting jit_options = Dict{Any,Any}() @@ -346,7 +345,67 @@ function _rocfunction(source::FunctionSpec; device=default_device(), queue=defau fun = ROCFunction(mod, kernel_fn) kernel = HostKernel{source.f,source.tt}(mod, fun) - #create_exceptions!(mod) + # initialize global output context + if any(x->x[1]==:__global_output_context, globals) + gbl = HSARuntime.get_global(exe, :__global_output_context) + gbl_ptr = Base.unsafe_convert(Ptr{GLOBAL_OUTPUT_CONTEXT_TYPE}, gbl) + oc = OutputContext(stdout) + Base.unsafe_store!(gbl_ptr, oc) + end + + # initialize global exception flag + if any(x->x[1]==:__global_exception_flag, globals) + gbl = HSARuntime.get_global(exe, :__global_exception_flag) + gbl_ptr = Base.unsafe_convert(Ptr{Int64}, gbl) + Base.unsafe_store!(gbl_ptr, 0) + + # TODO: initialize exception ring buffer + #= + @assert any(x->x[1]==:__global_exception_ring, globals) + gbl = HSARuntime.get_global(exe, :__global_exception_ring) + gbl_ptr = Base.unsafe_convert(Ptr{ExceptionEntry}, gbl) + erb = ExceptionEntry() + Base.unsafe_store!(gbl_ptr, erb) + =# + end + + # initialize global metadata store pointer + if any(x->x[1]==:__global_metadata_store, globals) + gbl = HSARuntime.get_global(exe, :__global_metadata_store) + gbl_ptr = Base.unsafe_convert(Ptr{Ptr{KernelMetadata}}, gbl) + # setup initial slots + for idx in 1:length(mod.metadata)-1 + mod.metadata[idx] = KernelMetadata(0) + end + # setup tail slot + mod.metadata[end] = KernelMetadata(1) + Base.unsafe_store!(gbl_ptr, pointer(mod.metadata)) + end + + # initialize malloc hostcall + if any(x->x[1]==:__global_malloc_hostcall, globals) + gbl = HSARuntime.get_global(exe, :__global_malloc_hostcall) + gbl_ptr = Base.unsafe_convert(Ptr{HostCall{UInt64,DevicePtr{UInt8,AS.Global},Tuple{Csize_t}}}, gbl) + hc = HostCall(DevicePtr{UInt8,AS.Global}, Tuple{Csize_t}; agent=device.device, continuous=true) do sz + @debug "Allocating $sz bytes for kernel on device $device" + return DevicePtr{UInt8,AS.Global}(Base.unsafe_convert(Ptr{UInt8}, + Mem.alloc(device.device, sz))) + end + Base.unsafe_store!(gbl_ptr, hc) + end + + # initialize free hostcall + if any(x->x[1]==:__global_free_hostcall, globals) + gbl = HSARuntime.get_global(exe, :__global_free_hostcall) + gbl_ptr = Base.unsafe_convert(Ptr{HostCall{UInt64,Nothing,Tuple{DevicePtr{UInt8,AS.Global}}}}, gbl) + hc = HostCall(Nothing, Tuple{DevicePtr{UInt8,AS.Global}}; agent=device.device, continuous=true) do ptr + @debug "Freeing $ptr for kernel on device $device" + buf = Mem.Buffer(Base.unsafe_convert(Ptr{Cvoid}, convert(Ptr{UInt8}, ptr)), 0, device.device) + Mem.free(buf) + return nothing + end + Base.unsafe_store!(gbl_ptr, hc) + end return kernel end diff --git a/src/execution_utils.jl b/src/execution_utils.jl index b5698d0..e8200b8 100644 --- a/src/execution_utils.jl +++ b/src/execution_utils.jl @@ -3,10 +3,19 @@ export ROCDim, ROCModule, ROCFunction, roccall +struct KernelMetadata + kern::UInt64 + data::DevicePtr{UInt8,AS.Global} + size::Csize_t +end +KernelMetadata(kern) = KernelMetadata(kern, DevicePtr{UInt8,AS.Global}(0), 0) mutable struct ROCModule{E} exe::RuntimeExecutable{E} options::Dict{Any,Any} + metadata::Vector{KernelMetadata} end +ROCModule(exe, options) = + ROCModule(exe, options, Vector{KernelMetadata}(undef,256)) mutable struct ROCFunction mod::ROCModule entry::String diff --git a/src/runtime.jl b/src/runtime.jl index 48630f1..e8ae746 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -18,9 +18,38 @@ default_isa(device::RuntimeDevice{HSAAgent}) = struct RuntimeEvent{E} event::E end -create_event() = RuntimeEvent(create_event(RUNTIME[])) -create_event(::typeof(HSA_rt)) = HSASignal() -Base.wait(event::RuntimeEvent; kwargs...) = wait(event.event; kwargs...) +create_event(exe) = RuntimeEvent(create_event(RUNTIME[], exe)) +Base.wait(event::RuntimeEvent, exe) = wait(event.event, exe) + +"Tracks the completion and status of a kernel's execution." +struct HSAStatusSignal + signal::HSASignal + exe::HSAExecutable +end +create_event(::typeof(HSA_rt), exe) = HSAStatusSignal(HSASignal(), exe.exe) +function Base.wait(event::RuntimeEvent{HSAStatusSignal}; kwargs...) + wait(event.event.signal; kwargs...) # wait for completion signal + exe = event.event.exe + agent = exe.agent + if haskey(exe.globals, :__global_exception_flag) + ex_flag = HSARuntime.get_global(exe, :__global_exception_flag) + ex_flag_ptr = Base.unsafe_convert(Ptr{Int64}, ex_flag) + ex_flag_value = Base.unsafe_load(ex_flag_ptr) + if ex_flag_value != 0 + if haskey(exe.globals, :__global_exception_ring) + ex_ring = HSARuntime.get_global(exe, :__global_exception_ring) + ex_ring_ptr = Base.unsafe_convert(Ptr{ExceptionEntry}, ex_ring) + ex_ring_value = Base.unsafe_load(ex_ring_ptr) + # FIXME: Check for and collect any exceptions, and clear their slots + # FIXME: Throw appropriate error + throw(KernelException(RuntimeDevice(agent))) + else + throw(KernelException(RuntimeDevice(agent))) + end + end + end +end + struct RuntimeExecutable{E} exe::E @@ -55,7 +84,9 @@ create_kernel(::typeof(HSA_rt), device, exe, entry, args) = HSAKernelInstance(device.device, exe.exe, entry, args) launch_kernel(queue, kern, event; kwargs...) = launch_kernel(RUNTIME[], queue, kern, event; kwargs...) -launch_kernel(::typeof(HSA_rt), queue, kern, event; - groupsize=nothing, gridsize=nothing) = - HSARuntime.launch!(queue.queue, kern.kernel, event.event; +function launch_kernel(::typeof(HSA_rt), queue, kern, event; + groupsize=nothing, gridsize=nothing) + signal = event.event isa HSAStatusSignal ? event.event.signal : event.event + HSARuntime.launch!(queue.queue, kern.kernel, signal; workgroup_size=groupsize, grid_size=gridsize) +end diff --git a/test/device/exceptions.jl b/test/device/exceptions.jl new file mode 100644 index 0000000..5c68685 --- /dev/null +++ b/test/device/exceptions.jl @@ -0,0 +1,12 @@ +@testset "Exceptions" begin + +function oob_kernel(X) + X[0] = 1 + nothing +end + +HA = HSAArray(ones(Float32, 4)) +_, msg = @grab_output(@test_throws AMDGPUnative.KernelException wait(@roc oob_kernel(HA)), stdout) +@test startswith(msg, "ERROR: an exception was thrown during kernel execution.\n") + +end diff --git a/test/device/memory.jl b/test/device/memory.jl index ddcfaf5..c81ef49 100644 --- a/test/device/memory.jl +++ b/test/device/memory.jl @@ -34,3 +34,49 @@ wait(@roc memory_static_kernel(HA, HB)) @test Array(HA) ≈ Array(HB) end + +@testset "Memory: Dynamic" begin + +function malloc_kernel(X) + ptr = AMDGPUnative.malloc(Csize_t(4)) + X[1] = ptr + AMDGPUnative.free(ptr) + nothing +end + +HA = HSAArray(zeros(UInt64, 1)) + +wait(@roc malloc_kernel(HA)) + +@test Array(HA)[1] != 0 + +end + +@testset "Memcpy/Memset" begin + +function memcpy_kernel(X,Y) + AMDGPUnative.memcpy!(Y.ptr, X.ptr, sizeof(Float32)*length(X)) + nothing +end + +A = rand(Float32, 4) +B = zeros(Float32, 4) +HA, HB = HSAArray.((A,B)) + +wait(@roc memcpy_kernel(A,B)) + +@test A == collect(HA) == collect(HB) + +function memset_kernel(X,y) + AMDGPUnative.memset!(X.ptr, y, div(length(X),2)) + nothing +end + +A = zeros(UInt8, 4) +HA = HSAArray(A) + +wait(@roc memset_kernel(X,0x3)) + +@test all(collect(HA) .== 0x3) + +end diff --git a/test/device/output.jl b/test/device/output.jl index 8bfe97a..09abaac 100644 --- a/test/device/output.jl +++ b/test/device/output.jl @@ -40,6 +40,17 @@ end @test String(take!(iob)) == "Hello World!Goodbye World!\n" end +@testset "Plain, global context" begin + function kernel(x) + @rocprint "Hello World!" + @rocprintln "Goodbye World!" + nothing + end + + _, msg = @grab_output wait(@roc kernel(1)) + @test msg == "Hello World!Goodbye World!\n" +end + #= TODO @testset "Interpolated string" begin inner_str = "to the" diff --git a/test/runtests.jl b/test/runtests.jl index ca21a46..0502133 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,8 @@ agent_name = HSARuntime.get_name(get_default_agent()) agent_isa = get_first_isa(get_default_agent()) @info "Testing using device $agent_name with ISA $agent_isa" +include("util.jl") + @testset "AMDGPUnative" begin @testset "Core" begin @@ -30,12 +32,9 @@ if AMDGPUnative.configured include("device/hostcall.jl") include("device/output.jl") include("device/globals.jl") - if Base.libllvm_version >= v"7.0" - include("device/math.jl") - else - @warn "Testing with LLVM 6; some tests will be disabled!" - @test_skip "Math Intrinsics" - end + include("device/math.jl") + include("device/exceptions.jl") + include("device/execution_control.jl") end end else diff --git a/test/util.jl b/test/util.jl new file mode 100644 index 0000000..4d31047 --- /dev/null +++ b/test/util.jl @@ -0,0 +1,20 @@ +# NOTE: based on test/pkg.jl::capture_stdout, but doesn't discard exceptions +macro grab_output(ex, io=stdout) + quote + mktemp() do fname, fout + ret = nothing + open(fname, "w") do fout + if $io == stdout + redirect_stdout(fout) do + ret = $(esc(ex)) + end + elseif $io == stderr + redirect_stderr(fout) do + ret = $(esc(ex)) + end + end + end + ret, read(fname, String) + end + end +end