From 6f322bdfcfbbececd92e4158307f033ac0853dda Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 4 Mar 2021 09:49:39 +0100 Subject: [PATCH 1/3] Use contextual dispatch for device functions. --- Manifest.toml | 29 ++- Project.toml | 1 + src/CUDA.jl | 13 ++ src/accumulate.jl | 2 - src/broadcast.jl | 99 +-------- src/compiler/gpucompiler.jl | 4 + src/device/intrinsics.jl | 29 +++ src/device/intrinsics/math.jl | 403 +++++++++++++++++----------------- src/forwarddiff.jl | 87 -------- src/initialization.jl | 5 +- src/mapreduce.jl | 3 - src/sorting.jl | 3 - test/Project.toml | 1 - test/broadcast.jl | 23 -- test/device/intrinsics.jl | 2 +- test/forwarddiff.jl | 72 ------ 16 files changed, 283 insertions(+), 493 deletions(-) delete mode 100644 src/forwarddiff.jl delete mode 100644 test/forwarddiff.jl diff --git a/Manifest.toml b/Manifest.toml index 2e45971102..25aa064a06 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -32,6 +32,12 @@ git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.1" +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "0.9.29" + [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" @@ -77,14 +83,21 @@ version = "6.2.0" [[GPUCompiler]] deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "ef2839b063e158672583b9c09d2cf4876a8d3d55" +git-tree-sha1 = "b6c3b8e2df6ffe0da0b10e2045ce35a3cf618b8a" +repo-rev = "1ecbe42" +repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.10.0" +version = "0.10.1" [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JLLWrappers]] +git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.2.0" + [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194" @@ -150,6 +163,12 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" [[NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+4" + [[OrderedCollections]] git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -205,6 +224,12 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +[[SpecialFunctions]] +deps = ["ChainRulesCore", "OpenSpecFun_jll"] +git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.3.0" + [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/Project.toml b/Project.toml index b7b5368a49..84df7c6251 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" diff --git a/src/CUDA.jl b/src/CUDA.jl index 57711f04fe..f51f125c8f 100644 --- a/src/CUDA.jl +++ b/src/CUDA.jl @@ -18,6 +18,19 @@ using BFloat16s using Memoize +using ExprTools + + +## + +const ci_cache = GPUCompiler.CodeCache() + +@static if VERSION >= v"1.7-" +Base.Experimental.@MethodTable(method_table) +else +const method_table = nothing +end + ## source code includes diff --git a/src/accumulate.jl b/src/accumulate.jl index 69413db9db..e76617388a 100644 --- a/src/accumulate.jl +++ b/src/accumulate.jl @@ -134,8 +134,6 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray; dims > ndims(input) && return copyto!(output, input) isempty(inds_t[dims]) && return output - f = cufunc(f) - # iteration domain across the main dimension Rdim = CartesianIndices((size(input, dims),)) diff --git a/src/broadcast.jl b/src/broadcast.jl index 0420cceaf4..fe11f7e1c1 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -14,99 +14,6 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} = Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims) where {N,T} = CuArray{T}(undef, dims) - -## replace base functions with libdevice alternatives - -cufunc(f) = f -cufunc(::Type{T}) where T = (x...) -> T(x...) # broadcasting type ctors isn't GPU compatible - -Broadcast.broadcasted(::CuArrayStyle{N}, f, args...) where {N} = - Broadcasted{CuArrayStyle{N}}(cufunc(f), args, nothing) - -const device_intrinsics = :[ - cos, cospi, sin, sinpi, tan, acos, asin, atan, - cosh, sinh, tanh, acosh, asinh, atanh, angle, - log, log10, log1p, log2, logb, ilogb, - exp, exp2, exp10, expm1, ldexp, - erf, erfinv, erfc, erfcinv, erfcx, - brev, clz, ffs, byte_perm, popc, - isfinite, isinf, isnan, nearbyint, - nextafter, signbit, copysign, abs, - sqrt, rsqrt, cbrt, rcbrt, pow, - ceil, floor, saturate, - lgamma, tgamma, - j0, j1, jn, y0, y1, yn, - normcdf, normcdfinv, hypot, - fma, sad, dim, mul24, mul64hi, hadd, rhadd, scalbn].args - -for f in device_intrinsics - isdefined(Base, f) || continue - @eval cufunc(::typeof(Base.$f)) = $f -end - -# broadcast ^ - -culiteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x) -culiteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x -culiteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x -culiteral_pow(::typeof(^), x::T, ::Val{3}) where {T<:Real} = x * x * x -culiteral_pow(::typeof(^), x::T, ::Val{p}) where {T<:Real,p} = pow(x, Int32(p)) - -cufunc(::typeof(Base.literal_pow)) = culiteral_pow -cufunc(::typeof(Base.:(^))) = pow - -using MacroTools - -const _cufuncs = [copy(device_intrinsics); :^] -cufuncs() = (global _cufuncs; _cufuncs) - -_cuint(x::Int) = Int32(x) -_cuint(x::Expr) = x.head == :call && x.args[1] == :Int32 && x.args[2] isa Int ? Int32(x.args[2]) : x -_cuint(x) = x - -function _cupowliteral(x::Expr) - if x.head == :call && x.args[1] == :(CUDA.cufunc(^)) && x.args[3] isa Int32 - num = x.args[3] - if 0 <= num <= 3 - sym = gensym(:x) - new_x = Expr(:block, :($sym = $(x.args[2]))) - - if iszero(num) - push!(new_x.args, :(one($sym))) - else - unroll = Expr(:call, :*) - for x = one(num):num - push!(unroll.args, sym) - end - push!(new_x.args, unroll) - end - - x = new_x - end - end - x -end -_cupowliteral(x) = x - -function replace_device(ex) - global _cufuncs - MacroTools.postwalk(ex) do x - x = x in _cufuncs ? :(CUDA.cufunc($x)) : x - x = _cuint(x) - x = _cupowliteral(x) - x - end -end - -macro cufunc(ex) - global _cufuncs - def = MacroTools.splitdef(ex) - f = def[:name] - def[:name] = Symbol(:cu, f) - def[:body] = replace_device(def[:body]) - push!(_cufuncs, f) - quote - $(esc(MacroTools.combinedef(def))) - CUDA.cufunc(::typeof($(esc(f)))) = $(esc(def[:name])) - end -end +# broadcasting type ctors isn't GPU compatible +Broadcast.broadcasted(::CuArrayStyle{N}, f::Type{T}, args...) where {N, T} = + Broadcasted{CuArrayStyle{N}}((x...) -> T(x...), args, nothing) diff --git a/src/compiler/gpucompiler.jl b/src/compiler/gpucompiler.jl index 928f0c36d1..50737ab198 100644 --- a/src/compiler/gpucompiler.jl +++ b/src/compiler/gpucompiler.jl @@ -39,3 +39,7 @@ function GPUCompiler.link_libraries!(job::CUDACompilerJob, mod::LLVM.Module, job, mod, undefined_fns) link_libdevice!(mod, job.target.cap, undefined_fns) end + +GPUCompiler.ci_cache(::CUDACompilerJob) = ci_cache + +GPUCompiler.method_table(::CUDACompilerJob) = method_table diff --git a/src/device/intrinsics.jl b/src/device/intrinsics.jl index ca03c0bf1f..18149ae6e7 100644 --- a/src/device/intrinsics.jl +++ b/src/device/intrinsics.jl @@ -1,5 +1,34 @@ # wrappers for functionality provided by the CUDA toolkit +const overrides = quote end + +macro device_override(ex) + code = quote + $GPUCompiler.@override($method_table, $ex) + end + if VERSION >= v"1.7-" + return esc(code) + else + push!(overrides.args, code) + return + end +end + +macro device_function(ex) + ex = macroexpand(__module__, ex) + def = splitdef(ex) + + # generate a function that errors + def[:body] = quote + error("This function is not intended for use on the CPU") + end + + esc(quote + $(combinedef(def)) + @device_override $ex + end) +end + # extensions to the C language include("intrinsics/memory_shared.jl") include("intrinsics/indexing.jl") diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 43d05f81bd..18511b094a 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -1,305 +1,304 @@ # math functionality +using Base: FastMath + +using SpecialFunctions + + ## trigonometric -@inline cos(x::Float64) = ccall("extern __nv_cos", llvmcall, Cdouble, (Cdouble,), x) -@inline cos(x::Float32) = ccall("extern __nv_cosf", llvmcall, Cfloat, (Cfloat,), x) -@inline cos_fast(x::Float32) = ccall("extern __nv_fast_cosf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.cos(x::Float64) = ccall("extern __nv_cos", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.cos(x::Float32) = ccall("extern __nv_cosf", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.cos_fast(x::Float32) = ccall("extern __nv_fast_cosf", llvmcall, Cfloat, (Cfloat,), x) -@inline cospi(x::Float64) = ccall("extern __nv_cospi", llvmcall, Cdouble, (Cdouble,), x) -@inline cospi(x::Float32) = ccall("extern __nv_cospif", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.cospi(x::Float64) = ccall("extern __nv_cospi", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.cospi(x::Float32) = ccall("extern __nv_cospif", llvmcall, Cfloat, (Cfloat,), x) -@inline sin(x::Float64) = ccall("extern __nv_sin", llvmcall, Cdouble, (Cdouble,), x) -@inline sin(x::Float32) = ccall("extern __nv_sinf", llvmcall, Cfloat, (Cfloat,), x) -@inline sin_fast(x::Float32) = ccall("extern __nv_fast_sinf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.sin(x::Float64) = ccall("extern __nv_sin", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.sin(x::Float32) = ccall("extern __nv_sinf", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.sin_fast(x::Float32) = ccall("extern __nv_fast_sinf", llvmcall, Cfloat, (Cfloat,), x) -@inline sinpi(x::Float64) = ccall("extern __nv_sinpi", llvmcall, Cdouble, (Cdouble,), x) -@inline sinpi(x::Float32) = ccall("extern __nv_sinpif", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.sinpi(x::Float64) = ccall("extern __nv_sinpi", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.sinpi(x::Float32) = ccall("extern __nv_sinpif", llvmcall, Cfloat, (Cfloat,), x) -@inline tan(x::Float64) = ccall("extern __nv_tan", llvmcall, Cdouble, (Cdouble,), x) -@inline tan(x::Float32) = ccall("extern __nv_tanf", llvmcall, Cfloat, (Cfloat,), x) -@inline tan_fast(x::Float32) = ccall("extern __nv_fast_tanf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.tan(x::Float64) = ccall("extern __nv_tan", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.tan(x::Float32) = ccall("extern __nv_tanf", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.tan_fast(x::Float32) = ccall("extern __nv_fast_tanf", llvmcall, Cfloat, (Cfloat,), x) ## inverse trigonometric -@inline acos(x::Float64) = ccall("extern __nv_acos", llvmcall, Cdouble, (Cdouble,), x) -@inline acos(x::Float32) = ccall("extern __nv_acosf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.acos(x::Float64) = ccall("extern __nv_acos", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.acos(x::Float32) = ccall("extern __nv_acosf", llvmcall, Cfloat, (Cfloat,), x) -@inline asin(x::Float64) = ccall("extern __nv_asin", llvmcall, Cdouble, (Cdouble,), x) -@inline asin(x::Float32) = ccall("extern __nv_asinf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.asin(x::Float64) = ccall("extern __nv_asin", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.asin(x::Float32) = ccall("extern __nv_asinf", llvmcall, Cfloat, (Cfloat,), x) -@inline atan(x::Float64) = ccall("extern __nv_atan", llvmcall, Cdouble, (Cdouble,), x) -@inline atan(x::Float32) = ccall("extern __nv_atanf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.atan(x::Float64) = ccall("extern __nv_atan", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.atan(x::Float32) = ccall("extern __nv_atanf", llvmcall, Cfloat, (Cfloat,), x) -# ! atan2 is equivalent to Base.atan -@inline atan2(x::Float64, y::Float64) = ccall("extern __nv_atan2", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline atan2(x::Float32, y::Float32) = ccall("extern __nv_atan2f", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline atan(x::Float64, y::Float64) = atan2(x, y) -@inline atan(x::Float32, y::Float32) = atan2(x, y) +@device_override Base.atan(x::Float64, y::Float64) = ccall("extern __nv_atan2", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.atan(x::Float32, y::Float32) = ccall("extern __nv_atan2f", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline angle(x::ComplexF64) = atan2(x.im, x.re) -@inline angle(x::ComplexF32) = atan2(x.im, x.re) -@inline angle(x::Float64) = signbit(x,) * 3.141592653589793 -@inline angle(x::Float32) = signbit(x,) * 3.1415927f0 +@device_override Base.angle(x::Float64) = signbit(x,) * 3.141592653589793 +@device_override Base.angle(x::Float32) = signbit(x,) * 3.1415927f0 ## hyperbolic -@inline cosh(x::Float64) = ccall("extern __nv_cosh", llvmcall, Cdouble, (Cdouble,), x) -@inline cosh(x::Float32) = ccall("extern __nv_coshf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.cosh(x::Float64) = ccall("extern __nv_cosh", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.cosh(x::Float32) = ccall("extern __nv_coshf", llvmcall, Cfloat, (Cfloat,), x) -@inline sinh(x::Float64) = ccall("extern __nv_sinh", llvmcall, Cdouble, (Cdouble,), x) -@inline sinh(x::Float32) = ccall("extern __nv_sinhf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.sinh(x::Float64) = ccall("extern __nv_sinh", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.sinh(x::Float32) = ccall("extern __nv_sinhf", llvmcall, Cfloat, (Cfloat,), x) -@inline tanh(x::Float64) = ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x) -@inline tanh(x::Float32) = ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.tanh(x::Float64) = ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.tanh(x::Float32) = ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x) ## inverse hyperbolic -@inline acosh(x::Float64) = ccall("extern __nv_acosh", llvmcall, Cdouble, (Cdouble,), x) -@inline acosh(x::Float32) = ccall("extern __nv_acoshf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.acosh(x::Float64) = ccall("extern __nv_acosh", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.acosh(x::Float32) = ccall("extern __nv_acoshf", llvmcall, Cfloat, (Cfloat,), x) -@inline asinh(x::Float64) = ccall("extern __nv_asinh", llvmcall, Cdouble, (Cdouble,), x) -@inline asinh(x::Float32) = ccall("extern __nv_asinhf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.asinh(x::Float64) = ccall("extern __nv_asinh", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.asinh(x::Float32) = ccall("extern __nv_asinhf", llvmcall, Cfloat, (Cfloat,), x) -@inline atanh(x::Float64) = ccall("extern __nv_atanh", llvmcall, Cdouble, (Cdouble,), x) -@inline atanh(x::Float32) = ccall("extern __nv_atanhf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.atanh(x::Float64) = ccall("extern __nv_atanh", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.atanh(x::Float32) = ccall("extern __nv_atanhf", llvmcall, Cfloat, (Cfloat,), x) ## logarithmic -@inline log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x) -@inline log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x) -@inline log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x) -@inline log(x::ComplexF64) = log(abs(x,)) + im * angle(x,) -@inline log(x::ComplexF32) = log(abs(x,)) + im * angle(x,) -@inline log_fast(x::ComplexF32) = log_fast(abs(x,)) + im * angle(x,) +@device_override Base.log(x::ComplexF64) = log(abs(x,)) + im * angle(x,) +@device_override Base.log(x::ComplexF32) = log(abs(x,)) + im * angle(x,) +@device_override FastMath.log_fast(x::ComplexF32) = FastMath.log_fast(abs(x,)) + im * angle(x,) -@inline log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x) -@inline log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x) -@inline log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x) -@inline log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x) -@inline log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x) -@inline log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x) -@inline log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x) -@inline log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x) -@inline logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x) -@inline logb(x::Float32) = ccall("extern __nv_logbf", llvmcall, Cfloat, (Cfloat,), x) +@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x) +@device_function logb(x::Float32) = ccall("extern __nv_logbf", llvmcall, Cfloat, (Cfloat,), x) -@inline ilogb(x::Float64) = ccall("extern __nv_ilogb", llvmcall, Int32, (Cdouble,), x) -@inline ilogb(x::Float32) = ccall("extern __nv_ilogbf", llvmcall, Int32, (Cfloat,), x) +@device_function ilogb(x::Float64) = ccall("extern __nv_ilogb", llvmcall, Int32, (Cdouble,), x) +@device_function ilogb(x::Float32) = ccall("extern __nv_ilogbf", llvmcall, Int32, (Cfloat,), x) ## exponential -@inline exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x) -@inline exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x) -@inline exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x) -@inline exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x) -@inline exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x) -@inline exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x) -@inline exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x) -@inline exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x) -@inline expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x) -@inline expm1(x::Float32) = ccall("extern __nv_expm1f", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.expm1(x::Float32) = ccall("extern __nv_expm1f", llvmcall, Cfloat, (Cfloat,), x) -@inline ldexp(x::Float64, y::Int32) = ccall("extern __nv_ldexp", llvmcall, Cdouble, (Cdouble, Int32), x, y) -@inline ldexp(x::Float32, y::Int32) = ccall("extern __nv_ldexpf", llvmcall, Cfloat, (Cfloat, Int32), x, y) +@device_override Base.ldexp(x::Float64, y::Int32) = ccall("extern __nv_ldexp", llvmcall, Cdouble, (Cdouble, Int32), x, y) +@device_override Base.ldexp(x::Float32, y::Int32) = ccall("extern __nv_ldexpf", llvmcall, Cfloat, (Cfloat, Int32), x, y) -@inline exp(x::Complex{Float64}) = exp(x.re) * (cos(x.im) + 1.0im * sin(x.im)) -@inline exp(x::Complex{Float32}) = exp(x.re) * (cos(x.im) + 1.0f0im * sin(x.im)) -@inline exp_fast(x::Complex{Float32}) = exp_fast(x.re) * (cos_fast(x.im) + 1.0f0im * sin_fast(x.im)) +@device_override Base.exp(x::Complex{Float64}) = exp(x.re) * (cos(x.im) + 1.0im * sin(x.im)) +@device_override Base.exp(x::Complex{Float32}) = exp(x.re) * (cos(x.im) + 1.0f0im * sin(x.im)) +@device_override FastMath.exp_fast(x::Complex{Float32}) = FastMath.exp_fast(x.re) * (FastMath.cos_fast(x.im) + 1.0f0im * FastMath.sin_fast(x.im)) ## error -@inline erf(x::Float64) = ccall("extern __nv_erf", llvmcall, Cdouble, (Cdouble,), x) -@inline erf(x::Float32) = ccall("extern __nv_erff", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.erf(x::Float64) = ccall("extern __nv_erf", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.erf(x::Float32) = ccall("extern __nv_erff", llvmcall, Cfloat, (Cfloat,), x) -@inline erfinv(x::Float64) = ccall("extern __nv_erfinv", llvmcall, Cdouble, (Cdouble,), x) -@inline erfinv(x::Float32) = ccall("extern __nv_erfinvf", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.erfinv(x::Float64) = ccall("extern __nv_erfinv", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.erfinv(x::Float32) = ccall("extern __nv_erfinvf", llvmcall, Cfloat, (Cfloat,), x) -@inline erfc(x::Float64) = ccall("extern __nv_erfc", llvmcall, Cdouble, (Cdouble,), x) -@inline erfc(x::Float32) = ccall("extern __nv_erfcf", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.erfc(x::Float64) = ccall("extern __nv_erfc", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.erfc(x::Float32) = ccall("extern __nv_erfcf", llvmcall, Cfloat, (Cfloat,), x) -@inline erfcinv(x::Float64) = ccall("extern __nv_erfcinv", llvmcall, Cdouble, (Cdouble,), x) -@inline erfcinv(x::Float32) = ccall("extern __nv_erfcinvf", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.erfcinv(x::Float64) = ccall("extern __nv_erfcinv", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.erfcinv(x::Float32) = ccall("extern __nv_erfcinvf", llvmcall, Cfloat, (Cfloat,), x) -@inline erfcx(x::Float64) = ccall("extern __nv_erfcx", llvmcall, Cdouble, (Cdouble,), x) -@inline erfcx(x::Float32) = ccall("extern __nv_erfcxf", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.erfcx(x::Float64) = ccall("extern __nv_erfcx", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.erfcx(x::Float32) = ccall("extern __nv_erfcxf", llvmcall, Cfloat, (Cfloat,), x) ## integer handling (bit twiddling) -@inline brev(x::Int32) = ccall("extern __nv_brev", llvmcall, Int32, (Int32,), x) -@inline brev(x::Int64) = ccall("extern __nv_brevll", llvmcall, Int64, (Int64,), x) +@device_function brev(x::Int32) = ccall("extern __nv_brev", llvmcall, Int32, (Int32,), x) +@device_function brev(x::Int64) = ccall("extern __nv_brevll", llvmcall, Int64, (Int64,), x) -@inline clz(x::Int32) = ccall("extern __nv_clz", llvmcall, Int32, (Int32,), x) -@inline clz(x::Int64) = ccall("extern __nv_clzll", llvmcall, Int32, (Int64,), x) +@device_function clz(x::Int32) = ccall("extern __nv_clz", llvmcall, Int32, (Int32,), x) +@device_function clz(x::Int64) = ccall("extern __nv_clzll", llvmcall, Int32, (Int64,), x) -@inline ffs(x::Int32) = ccall("extern __nv_ffs", llvmcall, Int32, (Int32,), x) -@inline ffs(x::Int64) = ccall("extern __nv_ffsll", llvmcall, Int32, (Int64,), x) +@device_function ffs(x::Int32) = ccall("extern __nv_ffs", llvmcall, Int32, (Int32,), x) +@device_function ffs(x::Int64) = ccall("extern __nv_ffsll", llvmcall, Int32, (Int64,), x) -@inline byte_perm(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_byte_perm", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) +@device_function byte_perm(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_byte_perm", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) -@inline popc(x::Int32) = ccall("extern __nv_popc", llvmcall, Int32, (Int32,), x) -@inline popc(x::Int64) = ccall("extern __nv_popcll", llvmcall, Int32, (Int64,), x) +@device_function popc(x::Int32) = ccall("extern __nv_popc", llvmcall, Int32, (Int32,), x) +@device_function popc(x::Int64) = ccall("extern __nv_popcll", llvmcall, Int32, (Int64,), x) ## floating-point handling -@inline isfinite(x::Float32) = (ccall("extern __nv_finitef", llvmcall, Int32, (Cfloat,), x)) != 0 -@inline isfinite(x::Float64) = (ccall("extern __nv_isfinited", llvmcall, Int32, (Cdouble,), x)) != 0 +@device_override Base.isfinite(x::Float32) = (ccall("extern __nv_finitef", llvmcall, Int32, (Cfloat,), x)) != 0 +@device_override Base.isfinite(x::Float64) = (ccall("extern __nv_isfinited", llvmcall, Int32, (Cdouble,), x)) != 0 -@inline isinf(x::Float64) = (ccall("extern __nv_isinfd", llvmcall, Int32, (Cdouble,), x)) != 0 -@inline isinf(x::Float32) = (ccall("extern __nv_isinff", llvmcall, Int32, (Cfloat,), x)) != 0 +@device_override Base.isinf(x::Float64) = (ccall("extern __nv_isinfd", llvmcall, Int32, (Cdouble,), x)) != 0 +@device_override Base.isinf(x::Float32) = (ccall("extern __nv_isinff", llvmcall, Int32, (Cfloat,), x)) != 0 -@inline isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0 -@inline isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0 +@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0 +@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0 -@inline nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x) -@inline nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x) +@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x) +@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x) -@inline nextafter(x::Float64, y::Float64) = ccall("extern __nv_nextafter", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline nextafter(x::Float32, y::Float32) = ccall("extern __nv_nextafterf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_function nextafter(x::Float64, y::Float64) = ccall("extern __nv_nextafter", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_function nextafter(x::Float32, y::Float32) = ccall("extern __nv_nextafterf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) ## sign handling -@inline signbit(x::Float64) = (ccall("extern __nv_signbitd", llvmcall, Int32, (Cdouble,), x)) != 0 -@inline signbit(x::Float32) = (ccall("extern __nv_signbitf", llvmcall, Int32, (Cfloat,), x)) != 0 +@device_override Base.signbit(x::Float64) = (ccall("extern __nv_signbitd", llvmcall, Int32, (Cdouble,), x)) != 0 +@device_override Base.signbit(x::Float32) = (ccall("extern __nv_signbitf", llvmcall, Int32, (Cfloat,), x)) != 0 -@inline copysign(x::Float64, y::Float64) = ccall("extern __nv_copysign", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline copysign(x::Float32, y::Float32) = ccall("extern __nv_copysignf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.copysign(x::Float64, y::Float64) = ccall("extern __nv_copysign", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.copysign(x::Float32, y::Float32) = ccall("extern __nv_copysignf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x) -@inline abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f) -@inline abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f) -@inline abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x) +@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x) +@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f) +@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f) +@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x) -@inline abs(x::Complex{Float64}) = hypot(x.re, x.im) -@inline abs(x::Complex{Float32}) = hypot(x.re, x.im) +@device_override Base.abs(x::Complex{Float64}) = hypot(x.re, x.im) +@device_override Base.abs(x::Complex{Float32}) = hypot(x.re, x.im) ## roots and powers -@inline sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) -@inline sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) -@inline rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x) -@inline rsqrt(x::Float32) = ccall("extern __nv_rsqrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x) +@device_function rsqrt(x::Float32) = ccall("extern __nv_rsqrtf", llvmcall, Cfloat, (Cfloat,), x) -@inline cbrt(x::Float64) = ccall("extern __nv_cbrt", llvmcall, Cdouble, (Cdouble,), x) -@inline cbrt(x::Float32) = ccall("extern __nv_cbrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.cbrt(x::Float64) = ccall("extern __nv_cbrt", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.cbrt(x::Float32) = ccall("extern __nv_cbrtf", llvmcall, Cfloat, (Cfloat,), x) -@inline rcbrt(x::Float64) = ccall("extern __nv_rcbrt", llvmcall, Cdouble, (Cdouble,), x) -@inline rcbrt(x::Float32) = ccall("extern __nv_rcbrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_function rcbrt(x::Float64) = ccall("extern __nv_rcbrt", llvmcall, Cdouble, (Cdouble,), x) +@device_function rcbrt(x::Float32) = ccall("extern __nv_rcbrtf", llvmcall, Cfloat, (Cfloat,), x) -@inline pow(x::Float64, y::Float64) = ccall("extern __nv_pow", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline pow(x::Float32, y::Float32) = ccall("extern __nv_powf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline pow_fast(x::Float32, y::Float32) = ccall("extern __nv_fast_powf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline pow(x::Float64, y::Int32) = ccall("extern __nv_powi", llvmcall, Cdouble, (Cdouble, Int32), x, y) -@inline pow(x::Float32, y::Int32) = ccall("extern __nv_powif", llvmcall, Cfloat, (Cfloat, Int32), x, y) -@inline pow(x::Union{Float32, Float64}, y::Int64) = pow(x, Int32(y,)) - -@inline abs2(x::Complex{Float64}) = x.re * x.re + x.im * x.im -@inline abs2(x::Complex{Float32}) = x.re * x.re + x.im * x.im +@device_override Base.:(^)(x::Float64, y::Float64) = ccall("extern __nv_pow", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.:(^)(x::Float32, y::Float32) = ccall("extern __nv_powf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override FastMath.pow_fast(x::Float32, y::Float32) = ccall("extern __nv_fast_powf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.:(^)(x::Float64, y::Int32) = ccall("extern __nv_powi", llvmcall, Cdouble, (Cdouble, Int32), x, y) +@device_override Base.:(^)(x::Float32, y::Int32) = ccall("extern __nv_powif", llvmcall, Cfloat, (Cfloat, Int32), x, y) +@device_override Base.:(^)(x::Float64, y::Int64) = x ^ Float64(y) +@device_override Base.:(^)(x::Float32, y::Int64) = x ^ Float32(y) ## rounding and selection # TODO: differentiate in return type, map correctly -# @inline round(x::Float64) = ccall("extern __nv_llround", llvmcall, Int64, (Cdouble,), x) -# @inline round(x::Float32) = ccall("extern __nv_llroundf", llvmcall, Int64, (Cfloat,), x) -# @inline round(x::Float64) = ccall("extern __nv_round", llvmcall, Cdouble, (Cdouble,), x) -# @inline round(x::Float32) = ccall("extern __nv_roundf", llvmcall, Cfloat, (Cfloat,), x) +#@device_override Base.round(x::Float64) = ccall("extern __nv_llround", llvmcall, Int64, (Cdouble,), x) +#@device_override Base.round(x::Float32) = ccall("extern __nv_llroundf", llvmcall, Int64, (Cfloat,), x) +#@device_override Base.round(x::Float64) = ccall("extern __nv_round", llvmcall, Cdouble, (Cdouble,), x) +#@device_override Base.round(x::Float32) = ccall("extern __nv_roundf", llvmcall, Cfloat, (Cfloat,), x) # TODO: differentiate in return type, map correctly -# @inline rint(x::Float64) = ccall("extern __nv_llrint", llvmcall, Int64, (Cdouble,), x) -# @inline rint(x::Float32) = ccall("extern __nv_llrintf", llvmcall, Int64, (Cfloat,), x) -# @inline rint(x::Float64) = ccall("extern __nv_rint", llvmcall, Cdouble, (Cdouble,), x) -# @inline rint(x::Float32) = ccall("extern __nv_rintf", llvmcall, Cfloat, (Cfloat,), x) +#@device_override Base.rint(x::Float64) = ccall("extern __nv_llrint", llvmcall, Int64, (Cdouble,), x) +#@device_override Base.rint(x::Float32) = ccall("extern __nv_llrintf", llvmcall, Int64, (Cfloat,), x) +#@device_override Base.rint(x::Float64) = ccall("extern __nv_rint", llvmcall, Cdouble, (Cdouble,), x) +#@device_override Base.rint(x::Float32) = ccall("extern __nv_rintf", llvmcall, Cfloat, (Cfloat,), x) -# TODO: would conflict with trunc usage in this module -# @inline trunc(x::Float64) = ccall("extern __nv_trunc", llvmcall, Cdouble, (Cdouble,), x) -# @inline trunc(x::Float32) = ccall("extern __nv_truncf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.trunc(x::Float64) = ccall("extern __nv_trunc", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.trunc(x::Float32) = ccall("extern __nv_truncf", llvmcall, Cfloat, (Cfloat,), x) -@inline ceil(x::Float64) = ccall("extern __nv_ceil", llvmcall, Cdouble, (Cdouble,), x) -@inline ceil(x::Float32) = ccall("extern __nv_ceilf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.ceil(x::Float64) = ccall("extern __nv_ceil", llvmcall, Cdouble, (Cdouble,), x) +@device_override Base.ceil(x::Float32) = ccall("extern __nv_ceilf", llvmcall, Cfloat, (Cfloat,), x) -@inline floor(f::Float64) = ccall("extern __nv_floor", llvmcall, Cdouble, (Cdouble,), f) -@inline floor(f::Float32) = ccall("extern __nv_floorf", llvmcall, Cfloat, (Cfloat,), f) +@device_override Base.floor(f::Float64) = ccall("extern __nv_floor", llvmcall, Cdouble, (Cdouble,), f) +@device_override Base.floor(f::Float32) = ccall("extern __nv_floorf", llvmcall, Cfloat, (Cfloat,), f) -@inline min(x::Int32, y::Int32) = ccall("extern __nv_min", llvmcall, Int32, (Int32, Int32), x, y) -@inline min(x::Int64, y::Int64) = ccall("extern __nv_llmin", llvmcall, Int64, (Int64, Int64), x, y) -@inline min(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umin", llvmcall, Int32, (Int32, Int32), x, y)) -@inline min(x::UInt64, y::UInt64) = convert(UInt64, ccall("extern __nv_ullmin", llvmcall, Int64, (Int64, Int64), x, y)) -@inline min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +#@device_override Base.min(x::Int32, y::Int32) = ccall("extern __nv_min", llvmcall, Int32, (Int32, Int32), x, y) +#@device_override Base.min(x::Int64, y::Int64) = ccall("extern __nv_llmin", llvmcall, Int64, (Int64, Int64), x, y) +#@device_override Base.min(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umin", llvmcall, Int32, (Int32, Int32), x, y)) +#@device_override Base.min(x::UInt64, y::UInt64) = convert(UInt64, ccall("extern __nv_ullmin", llvmcall, Int64, (Int64, Int64), x, y)) +@device_override Base.min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline max(x::Int32, y::Int32) = ccall("extern __nv_max", llvmcall, Int32, (Int32, Int32), x, y) -@inline max(x::Int64, y::Int64) = ccall("extern __nv_llmax", llvmcall, Int64, (Int64, Int64), x, y) -@inline max(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umax", llvmcall, Int32, (Int32, Int32), x, y)) -@inline max(x::UInt64, y::UInt64) = convert(UInt64, ccall("extern __nv_ullmax", llvmcall, Int64, (Int64, Int64), x, y)) -@inline max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +#@device_override Base.max(x::Int32, y::Int32) = ccall("extern __nv_max", llvmcall, Int32, (Int32, Int32), x, y) +#@device_override Base.max(x::Int64, y::Int64) = ccall("extern __nv_llmax", llvmcall, Int64, (Int64, Int64), x, y) +#@device_override Base.max(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umax", llvmcall, Int32, (Int32, Int32), x, y)) +#@device_override Base.max(x::UInt64, y::UInt64) = convert(UInt64, ccall("extern __nv_ullmax", llvmcall, Int64, (Int64, Int64), x, y)) +@device_override Base.max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline saturate(x::Float32) = ccall("extern __nv_saturatef", llvmcall, Cfloat, (Cfloat,), x) +@device_function saturate(x::Float32) = ccall("extern __nv_saturatef", llvmcall, Cfloat, (Cfloat,), x) ## division and remainder -@inline mod(x::Float64, y::Float64) = ccall("extern __nv_fmod", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline mod(x::Float32, y::Float32) = ccall("extern __nv_fmodf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.mod(x::Float64, y::Float64) = ccall("extern __nv_fmod", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.mod(x::Float32, y::Float32) = ccall("extern __nv_fmodf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline rem(x::Float64, y::Float64) = ccall("extern __nv_remainder", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline rem(x::Float32, y::Float32) = ccall("extern __nv_remainderf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.rem(x::Float64, y::Float64) = ccall("extern __nv_remainder", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.rem(x::Float32, y::Float32) = ccall("extern __nv_remainderf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline div_fast(x::Float32, y::Float32) = ccall("extern __nv_fast_fdividef", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override FastMath.div_fast(x::Float32, y::Float32) = ccall("extern __nv_fast_fdividef", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) ## gamma function -@inline lgamma(x::Float64) = ccall("extern __nv_lgamma", llvmcall, Cdouble, (Cdouble,), x) -@inline lgamma(x::Float32) = ccall("extern __nv_lgammaf", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.lgamma(x::Float64) = ccall("extern __nv_lgamma", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.lgamma(x::Float32) = ccall("extern __nv_lgammaf", llvmcall, Cfloat, (Cfloat,), x) -@inline tgamma(x::Float64) = ccall("extern __nv_tgamma", llvmcall, Cdouble, (Cdouble,), x) -@inline tgamma(x::Float32) = ccall("extern __nv_tgammaf", llvmcall, Cfloat, (Cfloat,), x) +@device_function tgamma(x::Float64) = ccall("extern __nv_tgamma", llvmcall, Cdouble, (Cdouble,), x) +@device_function tgamma(x::Float32) = ccall("extern __nv_tgammaf", llvmcall, Cfloat, (Cfloat,), x) ## Bessel -@inline j0(x::Float64) = ccall("extern __nv_j0", llvmcall, Cdouble, (Cdouble,), x) -@inline j0(x::Float32) = ccall("extern __nv_j0f", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.besselj0(x::Float64) = ccall("extern __nv_j0", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.besselj0(x::Float32) = ccall("extern __nv_j0f", llvmcall, Cfloat, (Cfloat,), x) -@inline j1(x::Float64) = ccall("extern __nv_j1", llvmcall, Cdouble, (Cdouble,), x) -@inline j1(x::Float32) = ccall("extern __nv_j1f", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.besselj1(x::Float64) = ccall("extern __nv_j1", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.besselj1(x::Float32) = ccall("extern __nv_j1f", llvmcall, Cfloat, (Cfloat,), x) -@inline jn(n::Int32, x::Float64) = ccall("extern __nv_jn", llvmcall, Cdouble, (Int32, Cdouble), n, x) -@inline jn(n::Int32, x::Float32) = ccall("extern __nv_jnf", llvmcall, Cfloat, (Int32, Cfloat), n, x) +@device_override SpecialFunctions.besselj(n::Int32, x::Float64) = ccall("extern __nv_jn", llvmcall, Cdouble, (Int32, Cdouble), n, x) +@device_override SpecialFunctions.besselj(n::Int32, x::Float32) = ccall("extern __nv_jnf", llvmcall, Cfloat, (Int32, Cfloat), n, x) -@inline y0(x::Float64) = ccall("extern __nv_y0", llvmcall, Cdouble, (Cdouble,), x) -@inline y0(x::Float32) = ccall("extern __nv_y0f", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.bessely0(x::Float64) = ccall("extern __nv_y0", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.bessely0(x::Float32) = ccall("extern __nv_y0f", llvmcall, Cfloat, (Cfloat,), x) -@inline y1(x::Float64) = ccall("extern __nv_y1", llvmcall, Cdouble, (Cdouble,), x) -@inline y1(x::Float32) = ccall("extern __nv_y1f", llvmcall, Cfloat, (Cfloat,), x) +@device_override SpecialFunctions.bessely1(x::Float64) = ccall("extern __nv_y1", llvmcall, Cdouble, (Cdouble,), x) +@device_override SpecialFunctions.bessely1(x::Float32) = ccall("extern __nv_y1f", llvmcall, Cfloat, (Cfloat,), x) -@inline yn(n::Int32, x::Float64) = ccall("extern __nv_yn", llvmcall, Cdouble, (Int32, Cdouble), n, x) -@inline yn(n::Int32, x::Float32) = ccall("extern __nv_ynf", llvmcall, Cfloat, (Int32, Cfloat), n, x) +@device_override SpecialFunctions.bessely(n::Int32, x::Float64) = ccall("extern __nv_yn", llvmcall, Cdouble, (Int32, Cdouble), n, x) +@device_override SpecialFunctions.bessely(n::Int32, x::Float32) = ccall("extern __nv_ynf", llvmcall, Cfloat, (Int32, Cfloat), n, x) ## distributions -@inline normcdf(x::Float64) = ccall("extern __nv_normcdf", llvmcall, Cdouble, (Cdouble,), x) -@inline normcdf(x::Float32) = ccall("extern __nv_normcdff", llvmcall, Cfloat, (Cfloat,), x) +# TODO: override StatsFun.jl? + +@device_function normcdf(x::Float64) = ccall("extern __nv_normcdf", llvmcall, Cdouble, (Cdouble,), x) +@device_function normcdf(x::Float32) = ccall("extern __nv_normcdff", llvmcall, Cfloat, (Cfloat,), x) -@inline normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x) -@inline normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x) +@device_function normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x) +@device_function normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x) @@ -307,31 +306,31 @@ # Unsorted # -@inline hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) -@inline fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) -@inline sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) -@inline sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)) +@device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) +@device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)) -@inline dim(x::Float64, y::Float64) = ccall("extern __nv_fdim", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) -@inline dim(x::Float32, y::Float32) = ccall("extern __nv_fdimf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_function dim(x::Float64, y::Float64) = ccall("extern __nv_fdim", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) +@device_function dim(x::Float32, y::Float32) = ccall("extern __nv_fdimf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) -@inline mul24(x::Int32, y::Int32) = ccall("extern __nv_mul24", llvmcall, Int32, (Int32, Int32), x, y) -@inline mul24(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umul24", llvmcall, Int32, (Int32, Int32), x, y)) +@device_function mul24(x::Int32, y::Int32) = ccall("extern __nv_mul24", llvmcall, Int32, (Int32, Int32), x, y) +@device_function mul24(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umul24", llvmcall, Int32, (Int32, Int32), x, y)) -@inline mul64hi(x::Int64, y::Int64) = ccall("extern __nv_mul64hi", llvmcall, Int64, (Int64, Int64), x, y) -@inline mul64hi(x::UInt64, y::UInt64) = convert(UInt64, ccall("extern __nv_umul64hi", llvmcall, Int64, (Int64, Int64), x, y)) -@inline mulhi(x::Int32, y::Int32) = ccall("extern __nv_mulhi", llvmcall, Int32, (Int32, Int32), x, y) -@inline mulhi(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umulhi", llvmcall, Int32, (Int32, Int32), x, y)) +@device_function mul64hi(x::Int64, y::Int64) = ccall("extern __nv_mul64hi", llvmcall, Int64, (Int64, Int64), x, y) +@device_function mul64hi(x::UInt64, y::UInt64) = convert(UInt64, ccall("extern __nv_umul64hi", llvmcall, Int64, (Int64, Int64), x, y)) +@device_function mulhi(x::Int32, y::Int32) = ccall("extern __nv_mulhi", llvmcall, Int32, (Int32, Int32), x, y) +@device_function mulhi(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_umulhi", llvmcall, Int32, (Int32, Int32), x, y)) -@inline hadd(x::Int32, y::Int32) = ccall("extern __nv_hadd", llvmcall, Int32, (Int32, Int32), x, y) -@inline hadd(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_uhadd", llvmcall, Int32, (Int32, Int32), x, y)) +@device_function hadd(x::Int32, y::Int32) = ccall("extern __nv_hadd", llvmcall, Int32, (Int32, Int32), x, y) +@device_function hadd(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_uhadd", llvmcall, Int32, (Int32, Int32), x, y)) -@inline rhadd(x::Int32, y::Int32) = ccall("extern __nv_rhadd", llvmcall, Int32, (Int32, Int32), x, y) -@inline rhadd(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_urhadd", llvmcall, Int32, (Int32, Int32), x, y)) +@device_function rhadd(x::Int32, y::Int32) = ccall("extern __nv_rhadd", llvmcall, Int32, (Int32, Int32), x, y) +@device_function rhadd(x::UInt32, y::UInt32) = convert(UInt32, ccall("extern __nv_urhadd", llvmcall, Int32, (Int32, Int32), x, y)) -@inline scalbn(x::Float64, y::Int32) = ccall("extern __nv_scalbn", llvmcall, Cdouble, (Cdouble, Int32), x, y) -@inline scalbn(x::Float32, y::Int32) = ccall("extern __nv_scalbnf", llvmcall, Cfloat, (Cfloat, Int32), x, y) +@device_function scalbn(x::Float64, y::Int32) = ccall("extern __nv_scalbn", llvmcall, Cdouble, (Cdouble, Int32), x, y) +@device_function scalbn(x::Float32, y::Int32) = ccall("extern __nv_scalbnf", llvmcall, Cfloat, (Cfloat, Int32), x, y) diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl deleted file mode 100644 index 093dc34ece..0000000000 --- a/src/forwarddiff.jl +++ /dev/null @@ -1,87 +0,0 @@ -# ForwardDiff integration - -byhand = [:exp2, :log2, :exp10, :log10, :abs] - -for f in device_intrinsics - if haskey(ForwardDiff.DiffRules.DEFINED_DIFFRULES, (:Base,f,1)) - f ∈ byhand && continue - diffrule = ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base,f,1)] - ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA,f,1)] = - (args...) -> replace_device(diffrule(args...)) - eval(ForwardDiff.unary_dual_definition(:CUDA, f)) - end -end - -# byhand: exp2 -ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA, :exp2, 1)] = x -> - :((CUDA.cufunc(exp2))(x) * (CUDA.cufunc(log))(oftype(x, 2))) -eval(ForwardDiff.unary_dual_definition(:CUDA, :exp2)) - -# byhand: log2 -ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA, :log2, 1)] = x -> - :(inv(x) / (CUDA.cufunc(log))(oftype(x, 2))) -eval(ForwardDiff.unary_dual_definition(:CUDA, :log2)) - -# byhand: exp10 -ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA, :exp10, 1)] = x -> - :((CUDA.cufunc(exp10))(x) * (CUDA.cufunc(log))(oftype(x, 10))) -eval(ForwardDiff.unary_dual_definition(:CUDA, :exp10)) - -# byhand: log10 -ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA, :log10, 1)] = x -> - :(inv(x) / (CUDA.cufunc(log))(oftype(x, 10))) -eval(ForwardDiff.unary_dual_definition(:CUDA, :log10)) - -# byhand: abs -ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA, :abs, 1)] = x -> - :(signbit(x) ? -one(x) : one(x)) -eval(ForwardDiff.unary_dual_definition(:CUDA, :abs)) - - -ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDA, :pow, 2)] = (x, y) -> - replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y)) - -@eval begin - ForwardDiff.@define_binary_dual_op( - CUDA.pow, - begin - vx = ForwardDiff.value(x) - vy = ForwardDiff.value(y) - expv = (CUDA.pow)(vx, vy) - - powval = vy * CUDA.pow(vx , vy - Int32(1)) - - py = ForwardDiff.partials(y) - px = ForwardDiff.partials(x) - - cond = all(py.values) do x - x == zero(x) - end - - if cond - logval = one(expv) - else - logval = expv * CUDA.log(vx) - end - - new_partials = powval * px + logval * py - return ForwardDiff.Dual{Txy}(expv, new_partials) - end, - begin - v = ForwardDiff.value(x) - expv = (CUDA.pow)(v, y) - if y == zero(y) - new_partials = zero(ForwardDiff.partials(x)) - else - new_partials = ForwardDiff.partials(x) * y * (CUDA.pow)(v, y - Int32(1)) - end - return ForwardDiff.Dual{Tx}(expv, new_partials) - end, - begin - v = ForwardDiff.value(y) - expv = (CUDA.pow)(x, v) - deriv = expv*CUDA.log(x) - return ForwardDiff.Dual{Ty}(expv, deriv * ForwardDiff.partials(y)) - end - ) -end diff --git a/src/initialization.jl b/src/initialization.jl index f202a6130f..afbebe3f21 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -91,7 +91,10 @@ function __init__() thread_streams[tid] = nothing end - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl") + precompiling = ccall(:jl_generating_output, Cint, ()) != 0 + if !precompiling + eval(overrides) + end end function __runtime_init__() diff --git a/src/mapreduce.jl b/src/mapreduce.jl index c85c646545..fcf8b1ef2a 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -143,9 +143,6 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T}, Base.check_reducedims(R, A) length(A) == 0 && return R # isempty(::Broadcasted) iterates - f = cufunc(f) - op = cufunc(op) - # be conservative about using shuffle instructions shuffle = T <: Union{Bool, Int32, Int64, Float32, Float64, ComplexF32, ComplexF64} diff --git a/src/sorting.jl b/src/sorting.jl index da753f9332..9bd497ce1e 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -401,9 +401,6 @@ function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int) where {T,N max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH) len = size(c, dims) - lt = CUDA.cufunc(lt) - by = CUDA.cufunc(by) - 1 <= dims <= N || throw(ArgumentError("dimension out of range")) otherdims = ntuple(i -> i == dims ? 1 : size(c, i), N) diff --git a/test/Project.toml b/test/Project.toml index 0b5469f953..5a81283cee 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,7 +6,6 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/broadcast.jl b/test/broadcast.jl index e01d7fe760..d2f7048b03 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -22,29 +22,6 @@ end @test Array(Whatever{Int}.(CuArray([1]))) == Whatever{Int}.([1]) end - -@testset "cufunc" begin - gelu1(x) = oftype(x, 0.5) * x * (1 + tanh(oftype(x, √(2/π))*(x + oftype(x, 0.044715) * x^3))) - sig(x) = one(x) / (one(x) + exp(-x)) - f(x) = gelu1(log(x)) * sig(x) * tanh(x) - g(x) = x^7 - 2 * x^f(x^2) + 3 - - CUDA.@cufunc gelu1(x) = oftype(x, 0.5) * x * (1 + tanh(oftype(x, √(2/π))*(x + oftype(x, 0.044715) * x^3))) - CUDA.@cufunc sig(x) = one(x) / (one(x) + exp(-x)) - CUDA.@cufunc f(x) = gelu1(log(x)) * sig(x) * tanh(x) - CUDA.@cufunc g(x) = x^7 - 2 * x^f(x^2) + 3 - - @test :gelu1 ∈ CUDA.cufuncs() - @test :sig ∈ CUDA.cufuncs() - @test :f ∈ CUDA.cufuncs() - @test :g ∈ CUDA.cufuncs() - - @test testf(x -> gelu1.(x), rand(3,3)) - @test testf(x -> sig.(x), rand(3,3)) - @test testf(x -> f.(x), rand(3,3)) - @test testf(x -> g.(x), rand(3,3)) -end - # https://github.com/JuliaGPU/CUDA.jl/issues/223 @testset "Ref Broadcast" begin foobar(idx, A) = A[idx] diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 49bcb3eacd..0a5e2f79a5 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -43,7 +43,7 @@ end buf = CuArray(zeros(Float32)) function pow_kernel(a, x, y) - a[] = CUDA.pow(x, y) + a[] = x^y return end diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl deleted file mode 100644 index 6e1d212e67..0000000000 --- a/test/forwarddiff.jl +++ /dev/null @@ -1,72 +0,0 @@ -using ForwardDiff -using ForwardDiff: Dual - -function test_derivative(f, x::T) where T - buf = CuArray(zeros(T)) - - function kernel(a, x) - a[] = ForwardDiff.derivative(f, x) - return - end - @cuda kernel(buf, x) - return CUDA.@allowscalar buf[] -end - -testdiff(cuf, f, x) = test_derivative(cuf, x) ≈ ForwardDiff.derivative(f, x) - - -@testset "UNARY" begin - fs = filter(x->x[1] ==:CUDA && x[3] == 1, keys(ForwardDiff.DiffRules.DEFINED_DIFFRULES)) - - - nonneg = [:log, :log1p, :log10, :log2, :sqrt, :acosh] - - for (m, fn, _) ∈ fs - cuf = @eval $m.$fn - f = @eval $fn - - x32 = rand(Float32) - x64 = rand(Float64) - nx32 = -x32 - nx64 = -x64 - - if fn == :acosh - x32 += 1 - x64 += 1 - end - - @test testdiff(cuf, f, x32) - @test testdiff(cuf, f, x64) - - if fn ∉ nonneg - @test testdiff(cuf, f, nx32) - @test testdiff(cuf, f, nx64) - end - end -end - -@testset "POW" begin - x32 = rand(Float32) - x64 = rand(Float64) - y32 = rand(Float32) - y64 = rand(Float64) - y = Int32(7) - - @test testdiff(x->CUDA.pow(x, Int32(7)), x->x^y, x32) - @test testdiff(x->CUDA.pow(x, y), x->x^y, x64) - @test testdiff(x->CUDA.pow(x, y32), x->x^y32, x32) - @test testdiff(x->CUDA.pow(x, y64), x->x^y64, x64) - - @test testdiff(y->CUDA.pow(x32, y), y->x32^y, y32) - @test testdiff(y->CUDA.pow(x64, y), y->x64^y, y64) - - @test testdiff(x->CUDA.pow(x, x), x->x^x, x32) - @test testdiff(x->CUDA.pow(x, x), x->x^x, x64) -end - -@testset "LITERAL_POW" begin - for x in [rand(Float32, 10), rand(Float64, 10)], - p in [1, 2, 3, 4, 5] - @test ForwardDiff.gradient(_x -> sum(_x .^ p), x) ≈ p .* (x .^ (p - 1)) - end -end From bcf1b82130a738bedddb3cfbec9b528b9913dfd0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 17 Mar 2021 08:54:02 +0100 Subject: [PATCH 2/3] Clean-up version checks. --- deps/compatibility.jl | 5 ----- examples/wmma/high-level.jl | 6 ------ examples/wmma/low-level.jl | 6 ------ src/CUDA.jl | 2 +- src/array.jl | 6 ------ src/compiler/execution.jl | 6 +----- src/device/intrinsics.jl | 2 +- src/mapreduce.jl | 6 ------ test/exceptions.jl | 6 +----- test/execution.jl | 2 -- test/runtests.jl | 2 +- 11 files changed, 5 insertions(+), 44 deletions(-) diff --git a/deps/compatibility.jl b/deps/compatibility.jl index a2aefee90c..b4dfd146d5 100644 --- a/deps/compatibility.jl +++ b/deps/compatibility.jl @@ -175,11 +175,6 @@ end ## high-level functions that return target and isa support function llvm_compat(version=LLVM.version()) - # https://github.com/JuliaGPU/CUDAnative.jl/issues/428 - if version >= v"8.0" && VERSION < v"1.3.0-DEV.547" - error("LLVM 8.0 requires a newer version of Julia") - end - InitializeNVPTXTarget() cap_support = sort(collect(llvm_cap_support(version))) diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index 98dd28fc8d..bfc1c46d30 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -1,9 +1,3 @@ -# Need https://github.com/JuliaLang/julia/pull/33970 -# and https://github.com/JuliaLang/julia/pull/34043 -if VERSION < v"1.5-" - exit() -end - using CUDA if capability(device()) < v"7.0" exit() diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index 0424bbae9f..fe41512e2e 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -1,9 +1,3 @@ -# Need https://github.com/JuliaLang/julia/pull/33970 -# and https://github.com/JuliaLang/julia/pull/34043 -if VERSION < v"1.5-" - exit() -end - using CUDA if capability(device()) < v"7.0" exit() diff --git a/src/CUDA.jl b/src/CUDA.jl index f51f125c8f..34d43fef3b 100644 --- a/src/CUDA.jl +++ b/src/CUDA.jl @@ -25,7 +25,7 @@ using ExprTools const ci_cache = GPUCompiler.CodeCache() -@static if VERSION >= v"1.7-" +@static if isdefined(Base.Experimental, Symbol("@overlay")) Base.Experimental.@MethodTable(method_table) else const method_table = nothing diff --git a/src/array.jl b/src/array.jl index 6dcf41427c..57e20abd98 100644 --- a/src/array.jl +++ b/src/array.jl @@ -476,12 +476,6 @@ function Base.reshape(a::CuArray{T,M}, dims::NTuple{N,Int}) where {T,N,M} return b end -# allow missing dimensions with Colon() -if VERSION < v"1.6.0-DEV.1358" -Base.reshape(parent::CuArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = - Base.reshape(parent, Base._reshape_uncolon(parent, dims)) -end - ## reinterpret diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index cbb9f7f550..e136e5d4c4 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -166,11 +166,7 @@ AbstractKernel args = (:F, (:( args[$i] ) for i in 1:length(args))...) # filter out arguments that shouldn't be passed - predicate = if VERSION >= v"1.5.0-DEV.581" - dt -> isghosttype(dt) || Core.Compiler.isconstType(dt) - else - dt -> isghosttype(dt) - end + predicate = dt -> isghosttype(dt) || Core.Compiler.isconstType(dt) to_pass = map(!predicate, sig.parameters) call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]] call_args = Union{Expr,Symbol}[x[1] for x in zip(args, to_pass) if x[2]] diff --git a/src/device/intrinsics.jl b/src/device/intrinsics.jl index 18149ae6e7..0e0814b571 100644 --- a/src/device/intrinsics.jl +++ b/src/device/intrinsics.jl @@ -6,7 +6,7 @@ macro device_override(ex) code = quote $GPUCompiler.@override($method_table, $ex) end - if VERSION >= v"1.7-" + if isdefined(Base.Experimental, Symbol("@overlay")) return esc(code) else push!(overrides.args, code) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index fcf8b1ef2a..8965f26b52 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -131,12 +131,6 @@ end ## COV_EXCL_STOP -if VERSION < v"1.5.0-DEV.748" - Base.axes(bc::Base.Broadcast.Broadcasted{<:CuArrayStyle, <:NTuple{N}}, - d::Integer) where N = - d <= N ? axes(bc)[d] : Base.OneTo(1) -end - function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T}, A::Union{AbstractArray,Broadcast.Broadcasted}; init=nothing) where {F, OP, T} diff --git a/test/exceptions.jl b/test/exceptions.jl index 5198c93478..ed8d4bd8b1 100644 --- a/test/exceptions.jl +++ b/test/exceptions.jl @@ -54,11 +54,7 @@ let (code, out, err) = julia_script(script, `-g2`) occursin("ERROR: CUDA error: an illegal instruction was encountered", err) || occursin("ERROR: CUDA error: unspecified launch failure", err) @test occursin(r"ERROR: a \w+ was thrown during kernel execution", out) - if VERSION < v"1.3.0-DEV.270" - @test occursin("[1] Type at float.jl", out) - else - @test occursin("[1] Int64 at float.jl", out) - end + @test occursin("[1] Int64 at float.jl", out) @test occursin("[4] kernel at none:5", out) end diff --git a/test/execution.jl b/test/execution.jl index 093f6f7b6c..9b7a63ddaa 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -854,7 +854,6 @@ end @test out == "Hello, World!" end -if VERSION >= v"1.1" # behavior of captured variables (box or not) has improved over time @testset "closures" begin function hello() x = 1 @@ -869,7 +868,6 @@ if VERSION >= v"1.1" # behavior of captured variables (box or not) has improved end @test out == "Hello, World 1!" end -end @testset "argument passing" begin ## padding diff --git a/test/runtests.jl b/test/runtests.jl index f492e9deba..3206b346d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -163,7 +163,7 @@ if !has_cutensor() || CUDA.version() < v"10.1" || first(picks).cap < v"7.0" push!(skip_tests, "cutensor") end is_debug = ccall(:jl_is_debugbuild, Cint, ()) != 0 -if VERSION < v"1.5-" || first(picks).cap < v"7.0" +if first(picks).cap < v"7.0" push!(skip_tests, "device/wmma") end if Sys.ARCH == :aarch64 From 90b93660c19024e9c3c44490a3e0d3ff94d89a5d Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 17 Mar 2021 14:53:04 +0100 Subject: [PATCH 3/3] Add tests. --- test/execution.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/execution.jl b/test/execution.jl index 9b7a63ddaa..c21fcbf28f 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -1059,3 +1059,24 @@ end end ############################################################################################ + +@testset "contextual dispatch" begin + +@test_throws ErrorException CUDA.saturate(1f0) # CUDA.jl#60 + +@test testf(a->broadcast(x->x^1.5, a), rand(Float32, 1)) # CUDA.jl#71 +@test testf(a->broadcast(x->1.0^x, a), rand(Int, 1)) # CUDA.jl#76 +@test testf(a->broadcast(x->x^4, a), rand(Float32, 1)) # CUDA.jl#171 + +@test argmax(cu([true false; false true])) == CartesianIndex(1, 1) # CUDA.jl#659 + +# CUDA.jl#42 +@test testf([Complex(1f0,2f0)]) do a + b = sincos.(a) + s,c = first(collect(b)) + (real(s), imag(s), real(c), imag(c)) +end + +end + +############################################################################################