
Use contextual dispatch for device functions. #750

Merged · 3 commits · Mar 17, 2021
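
For context: this PR replaces the old `cufunc`-based source rewriting with contextual dispatch, where device-specific methods live in an overlay method table that the GPU compiler consults ahead of the regular one. A minimal sketch of the underlying Julia mechanism (Julia 1.7+ overlay syntax; `device_sin` is a hypothetical device-side implementation):

    # An empty overlay method table, and an overlaid method stored in it.
    # Ordinary calls to Base.sin are unaffected; only a compiler that is
    # explicitly pointed at `mt` (here, GPUCompiler) sees the override.
    Base.Experimental.@MethodTable(mt)
    Base.Experimental.@overlay mt Base.sin(x::Float64) = device_sin(x)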
29 changes: 27 additions & 2 deletions Manifest.toml
@@ -32,6 +32,12 @@ git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.29"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b"
@@ -77,14 +83,21 @@ version = "6.2.0"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "ef2839b063e158672583b9c09d2cf4876a8d3d55"
git-tree-sha1 = "b6c3b8e2df6ffe0da0b10e2045ce35a3cf618b8a"
repo-rev = "1ecbe42"
repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.10.0"
version = "0.10.1"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JLLWrappers]]
git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.2.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
@@ -150,6 +163,12 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -205,6 +224,12 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1 change: 1 addition & 0 deletions Project.toml
@@ -24,6 +24,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

5 changes: 0 additions & 5 deletions deps/compatibility.jl
@@ -175,11 +175,6 @@ end
## high-level functions that return target and isa support

function llvm_compat(version=LLVM.version())
# https://github.com/JuliaGPU/CUDAnative.jl/issues/428
if version >= v"8.0" && VERSION < v"1.3.0-DEV.547"
error("LLVM 8.0 requires a newer version of Julia")
end

InitializeNVPTXTarget()

cap_support = sort(collect(llvm_cap_support(version)))
6 changes: 0 additions & 6 deletions examples/wmma/high-level.jl
@@ -1,9 +1,3 @@
# Need https://github.com/JuliaLang/julia/pull/33970
# and https://github.com/JuliaLang/julia/pull/34043
if VERSION < v"1.5-"
exit()
end

using CUDA
if capability(device()) < v"7.0"
exit()
6 changes: 0 additions & 6 deletions examples/wmma/low-level.jl
@@ -1,9 +1,3 @@
# Need https://github.com/JuliaLang/julia/pull/33970
# and https://github.com/JuliaLang/julia/pull/34043
if VERSION < v"1.5-"
exit()
end

using CUDA
if capability(device()) < v"7.0"
exit()
13 changes: 13 additions & 0 deletions src/CUDA.jl
@@ -18,6 +18,19 @@ using BFloat16s

using Memoize

using ExprTools


##

const ci_cache = GPUCompiler.CodeCache()

@static if isdefined(Base.Experimental, Symbol("@overlay"))
Base.Experimental.@MethodTable(method_table)
else
const method_table = nothing
end


## source code includes

2 changes: 0 additions & 2 deletions src/accumulate.jl
@@ -134,8 +134,6 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
dims > ndims(input) && return copyto!(output, input)
isempty(inds_t[dims]) && return output

f = cufunc(f)

# iteration domain across the main dimension
Rdim = CartesianIndices((size(input, dims),))

6 changes: 0 additions & 6 deletions src/array.jl
@@ -476,12 +476,6 @@ function Base.reshape(a::CuArray{T,M}, dims::NTuple{N,Int}) where {T,N,M}
return b
end

# allow missing dimensions with Colon()
if VERSION < v"1.6.0-DEV.1358"
Base.reshape(parent::CuArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
Base.reshape(parent, Base._reshape_uncolon(parent, dims))
end


## reinterpret

99 changes: 3 additions & 96 deletions src/broadcast.jl
@@ -14,99 +14,6 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =
Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims) where {N,T} =
CuArray{T}(undef, dims)


## replace base functions with libdevice alternatives

cufunc(f) = f
cufunc(::Type{T}) where T = (x...) -> T(x...) # broadcasting type ctors isn't GPU compatible

Broadcast.broadcasted(::CuArrayStyle{N}, f, args...) where {N} =
Broadcasted{CuArrayStyle{N}}(cufunc(f), args, nothing)

const device_intrinsics = :[
cos, cospi, sin, sinpi, tan, acos, asin, atan,
cosh, sinh, tanh, acosh, asinh, atanh, angle,
log, log10, log1p, log2, logb, ilogb,
exp, exp2, exp10, expm1, ldexp,
erf, erfinv, erfc, erfcinv, erfcx,
brev, clz, ffs, byte_perm, popc,
isfinite, isinf, isnan, nearbyint,
nextafter, signbit, copysign, abs,
sqrt, rsqrt, cbrt, rcbrt, pow,
ceil, floor, saturate,
lgamma, tgamma,
j0, j1, jn, y0, y1, yn,
normcdf, normcdfinv, hypot,
fma, sad, dim, mul24, mul64hi, hadd, rhadd, scalbn].args

for f in device_intrinsics
isdefined(Base, f) || continue
@eval cufunc(::typeof(Base.$f)) = $f
end

# broadcast ^

culiteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x)
culiteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x
culiteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x
culiteral_pow(::typeof(^), x::T, ::Val{3}) where {T<:Real} = x * x * x
culiteral_pow(::typeof(^), x::T, ::Val{p}) where {T<:Real,p} = pow(x, Int32(p))

cufunc(::typeof(Base.literal_pow)) = culiteral_pow
cufunc(::typeof(Base.:(^))) = pow

using MacroTools

const _cufuncs = [copy(device_intrinsics); :^]
cufuncs() = (global _cufuncs; _cufuncs)

_cuint(x::Int) = Int32(x)
_cuint(x::Expr) = x.head == :call && x.args[1] == :Int32 && x.args[2] isa Int ? Int32(x.args[2]) : x
_cuint(x) = x

function _cupowliteral(x::Expr)
if x.head == :call && x.args[1] == :(CUDA.cufunc(^)) && x.args[3] isa Int32
num = x.args[3]
if 0 <= num <= 3
sym = gensym(:x)
new_x = Expr(:block, :($sym = $(x.args[2])))

if iszero(num)
push!(new_x.args, :(one($sym)))
else
unroll = Expr(:call, :*)
for x = one(num):num
push!(unroll.args, sym)
end
push!(new_x.args, unroll)
end

x = new_x
end
end
x
end
_cupowliteral(x) = x

function replace_device(ex)
global _cufuncs
MacroTools.postwalk(ex) do x
x = x in _cufuncs ? :(CUDA.cufunc($x)) : x
x = _cuint(x)
x = _cupowliteral(x)
x
end
end

macro cufunc(ex)
global _cufuncs
def = MacroTools.splitdef(ex)
f = def[:name]
def[:name] = Symbol(:cu, f)
def[:body] = replace_device(def[:body])
push!(_cufuncs, f)
quote
$(esc(MacroTools.combinedef(def)))
CUDA.cufunc(::typeof($(esc(f)))) = $(esc(def[:name]))
end
end
# broadcasting type ctors isn't GPU compatible
Broadcast.broadcasted(::CuArrayStyle{N}, f::Type{T}, args...) where {N, T} =
Broadcasted{CuArrayStyle{N}}((x...) -> T(x...), args, nothing)
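
With the `cufunc` machinery removed, broadcasting Base functions over GPU arrays now resolves through the overlay method table during kernel compilation. A usage sketch (assumes a functional CUDA setup):

    using CUDA

    a = CUDA.rand(Float32, 1024)
    b = sin.(a) .+ exp.(a)   # plain Base functions, overridden for device code
    c = a .^ 3               # integer-literal powers need no special rewrite either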
6 changes: 1 addition & 5 deletions src/compiler/execution.jl
@@ -166,11 +166,7 @@ AbstractKernel
args = (:F, (:( args[$i] ) for i in 1:length(args))...)

# filter out arguments that shouldn't be passed
predicate = if VERSION >= v"1.5.0-DEV.581"
dt -> isghosttype(dt) || Core.Compiler.isconstType(dt)
else
dt -> isghosttype(dt)
end
predicate = dt -> isghosttype(dt) || Core.Compiler.isconstType(dt)
to_pass = map(!predicate, sig.parameters)
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
call_args = Union{Expr,Symbol}[x[1] for x in zip(args, to_pass) if x[2]]
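For reference, the predicate drops "ghost type" arguments (zero-size types that carry no runtime data, which the device can materialize itself) and singleton `Type{T}` arguments, which are compile-time constants. An illustrative sketch (`isghosttype` is assumed to be the GPUCompiler helper used above):

    struct Empty end                        # no fields: sizeof(Empty) == 0
    # isghosttype(Empty)                    # -> true; not passed to the kernel
    # Core.Compiler.isconstType(Type{Int})  # -> true; a constant, not a runtime arg
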
4 changes: 4 additions & 0 deletions src/compiler/gpucompiler.jl
@@ -39,3 +39,7 @@ function GPUCompiler.link_libraries!(job::CUDACompilerJob, mod::LLVM.Module,
job, mod, undefined_fns)
link_libdevice!(mod, job.target.cap, undefined_fns)
end

GPUCompiler.ci_cache(::CUDACompilerJob) = ci_cache

GPUCompiler.method_table(::CUDACompilerJob) = method_table
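
These two overloads are what opt the CUDA compiler into the new machinery: GPUCompiler consults `method_table` for device overrides and `ci_cache` for inferred code instances. A downstream GPU backend could wire up the same pattern (a sketch; `MyCompilerJob` is a hypothetical job type):

    using GPUCompiler

    const my_ci_cache = GPUCompiler.CodeCache()
    Base.Experimental.@MethodTable(my_method_table)

    GPUCompiler.ci_cache(::MyCompilerJob)     = my_ci_cache
    GPUCompiler.method_table(::MyCompilerJob) = my_method_table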
29 changes: 29 additions & 0 deletions src/device/intrinsics.jl
@@ -1,5 +1,34 @@
# wrappers for functionality provided by the CUDA toolkit

const overrides = quote end

macro device_override(ex)
code = quote
$GPUCompiler.@override($method_table, $ex)
end
if isdefined(Base.Experimental, Symbol("@overlay"))
return esc(code)
else
push!(overrides.args, code)
return
end
end

macro device_function(ex)
ex = macroexpand(__module__, ex)
def = splitdef(ex)

# generate a function that errors
def[:body] = quote
error("This function is not intended for use on the CPU")
end

esc(quote
$(combinedef(def))
@device_override $ex
end)
end

# extensions to the C language
include("intrinsics/memory_shared.jl")
include("intrinsics/indexing.jl")
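
A hypothetical usage sketch of the two macros above (the `__nv_cos` ccall mirrors the libdevice-wrapping pattern CUDA.jl uses for its math intrinsics; `gpu_only_helper` is an invented name):

    # Route Base.cos to the libdevice implementation when compiling for GPU:
    @device_override Base.cos(x::Float64) =
        ccall("extern __nv_cos", llvmcall, Cdouble, (Cdouble,), x)

    # Define a device-only function: the CPU method generated by
    # @device_function throws, while the device method gets the real body.
    @device_function gpu_only_helper(x) = x * 2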