diff --git a/docs/make.jl b/docs/make.jl index 3cbd62d523..b75e790825 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -60,6 +60,9 @@ function main() "development/troubleshooting.md", "development/debugging.md", ], + "Hacking" => Any[ + "hacking/exposing_new_intrinsics.md", + ], "API reference" => Any[ "api/essentials.md", "api/array.md", diff --git a/docs/src/hacking/exposing_new_intrinsics.jl b/docs/src/hacking/exposing_new_intrinsics.jl new file mode 100644 index 0000000000..662940af27 --- /dev/null +++ b/docs/src/hacking/exposing_new_intrinsics.jl @@ -0,0 +1,49 @@ +# # Introduction + +# * Adding new GPU intrinsics * + +# In this tutorial we will expose some GPU intrinsics to allow directed rounding in fused-multiply-add (fma) +# floating point operation +# We start by identifying the intrinsic we want to expose; to do so, we read the PTX (Parallel Thread Execution) +# documentation at [PTX - Floating Point Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions). +# In table 32, it is presented a summary of floating point operations: we can construct the intrinsic string from that. +# The FMA instruction for Float32 is presented as `{mad,fma}.rnd.f32`, where `rnd` can assume the values `.rnd = { .rn, .rz, .rm, .rp }`, +# where `rn` is round to nearest, `rz` round to zero, `rm` round to minus infinity, `rp` round to plus infinity. +# When building the intrinsic for the call, we need to change the type `.f64` with `.d` and `.f32` with `.f` +# Therefore, to call the rounded towards infinity `fma` for `.f64` we need to call the intrinsic `llvm.nvvm.fma.rp.d` +# Please remark that this is only possible if LLVM support the intrinsic; a source for those exposed by LLVM +# may be found by searching the [LLVM repository](https://github.com/llvm/llvm-project). In in other cases you'd need @asmcall and inline PTX assembly. + +fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z) + +# We inspect the PTX code +CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64}) + +# It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we add this function now +# to src/device/intrins/math.jl + +using CUDA +function test_fma!(out, x, y) + I = threadIdx().x + z = (2.0) ^ (-(I+53)) + + out[I] = fma(x, y, z, RoundNearest) + out[I+4] = fma(x, y, z, RoundToZero) + out[I+8] = fma(x, y, z, RoundUp) + out[I+12] = fma(x, y, z, RoundDown) + + return +end + +# The first four entries of the output are Rounded to Nearest, the entries 5 to 8 are rounded towards zero, +# etc... + +out_d = CuArray(zeros(16)) +@cuda threads = 4 test_fma!(out_d, 1.0, 1.0) +out_h = Array(out_d) + +out_d = CuArray(zeros(4)) +@cuda threads = 4 test_fma!(out_d, -1.0, 1.0) +out_h = Array(out_d) + diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index a1d589721d..bc9f82eb96 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -390,8 +390,6 @@ end @device_function normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x) @device_function normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x) - - # # Unsorted # @@ -399,9 +397,70 @@ end @device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) @device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) + +for type in [:f, :d] + for round in [:rn, :rz, :rm, :rp] + for op in [:add, :mul, :div] + + inp_type = Symbol("Float64") + c_type = Symbol("Cdouble") + if type == :f + inp_type = Symbol("Float32") + c_type = Symbol("Cfloat") + end + + func_name = Symbol("$(op)_$(round)") + intrinsic_name = "llvm.nvvm.$(op).$(round).$(type)" + #@info func_name, intrinsic_name + + @eval @device_function $func_name(x::$inp_type, y::$inp_type) = ccall($intrinsic_name, llvmcall, $c_type, ($c_type, $c_type), x, y) + end + end +end + +@device_function sub_rn(x, y) = add_rn(x, -y) +@device_function sub_rz(x, y) = add_rz(x, -y) +@device_function sub_rm(x, y) = add_rm(x, -y) +@device_function sub_rp(x, y) = add_rp(x, -y) + +@device_function add(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = add_rn(x, y) +@device_function add(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = add_rz(x, y) +@device_function add(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = add_rm(x, y) +@device_function add(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = add_rp(x, y) + +@device_function sub(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = sub_rn(x, y) +@device_function sub(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = sub_rz(x, y) +@device_function sub(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = sub_rm(x, y) +@device_function sub(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = sub_rp(x, y) + +@device_function mul(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = mul_rn(x, y) +@device_function mul(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = mul_rz(x, y) +@device_function mul(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = mul_rm(x, y) +@device_function mul(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = mul_rp(x, y) + +@device_function div(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = div_rn(x, y) +@device_function div(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = div_rz(x, y) +@device_function div(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = div_rm(x, y) +@device_function div(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = div_rp(x, y) + + + @device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) @device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) @device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) +@device_function fma_rn(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rn.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rn(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rn.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_function fma_rz(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rz.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rz(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rz.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_function fma_rm(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rm.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_function fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) + +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = fma_rn(x, y, z) +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = fma_rz(x, y, z) +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = fma_rm(x, y, z) +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z) @device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) @device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index b40bbffe2d..f6da7d90c9 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -27,7 +27,9 @@ const map_ptx_to_jl_frag = Dict( "f32" => Float32 ) -# Maps matrix & PTX types to fragment sizes +# Maps matrix & PTX types to fragment sizes, information retrieved from +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma + const map_frag_sizes = Dict( # A "a.u8.m16n16k16" => 2, @@ -491,7 +493,9 @@ julia> config = WMMA.Config{16, 16, 16, Float32} CUDA.WMMA.Config{16, 16, 16, Float32} ``` """ -struct Config{M, N, K, d_type} end +struct ConfigRounding{M, N, K, d_type, rounding} end + +Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest} # --------- # Constants