Commit 30105ee

Merge branch 'DirectedRounding' of github.com:orkolorko/CUDA.jl into DirectedRounding

orkolorko committed Dec 20, 2024
2 parents 5f8fff5 + 48c36d0
Showing 2 changed files with 124 additions and 72 deletions.
58 changes: 3 additions & 55 deletions docs/src/hacking/exposing_new_intrinsics.jl
@@ -11,6 +11,8 @@
# where `rn` is round to nearest, `rz` round to zero, `rm` round to minus infinity, `rp` round to plus infinity.
# When building the name of the intrinsic for the call, we need to replace the type suffix `.f64` with `.d` and `.f32` with `.f`
# Therefore, to call the rounded towards infinity `fma` for `.f64` we need to call the intrinsic `llvm.nvvm.fma.rp.d`
# Note that this is only possible if LLVM supports the intrinsic; the intrinsics exposed by LLVM
# may be found by searching the [LLVM repository](https://github.com/llvm/llvm-project). In other cases you would need `@asmcall` and inline PTX assembly.

fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)
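# Following the same naming pattern, the other rounding modes can be exposed the same way.
# A sketch (assuming LLVM exposes these intrinsics, which can be checked as described above;
# the `.f` suffix selects the `Float32` variant):

fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
fma_rm(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rm.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
fma(x::T, y::T, z::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = fma_rm(x, y, z)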
@@ -21,6 +23,7 @@ CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64})
# It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we now add this function
# to src/device/intrinsics/math.jl

using CUDA
function test_fma!(out, x, y)
I = threadIdx().x
z = (2.0) ^ (-(I+53))
@@ -44,58 +47,3 @@ out_d = CuArray(zeros(4))
@cuda threads = 4 test_fma!(out_d, -1.0, 1.0)
out_h = Array(out_d)

# The binary operations add, sub, mul and div have been implemented through a macro, sketched below
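# A minimal sketch of how such generation could look (illustrative only; the
# actual implementation lives in src/device/intrinsics/math.jl and may differ).
# NVVM exposes `add`, `mul` and `div` with explicit rounding suffixes; a
# directed-rounding `sub` can be derived from `add` with a negated operand,
# since negation is exact:

for (op, intr) in [(:add, "add"), (:mul, "mul"), (:div, "div")]
    for (mode, suffix) in [(:Nearest, "rn"), (:ToZero, "rz"), (:Up, "rp"), (:Down, "rm")]
        @eval $op(x::Float64, y::Float64, ::RoundingMode{$(QuoteNode(mode))}) =
            ccall($("llvm.nvvm.$intr.$suffix.d"), llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
    end
end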

function test_add!(out, x, y)
I = threadIdx().x
if I == 1
out[I] = CUDA.add(x, y, RoundNearest)
elseif I == 2
out[I] = CUDA.add(x, y, RoundToZero)
elseif I == 3
out[I] = CUDA.add(x, y, RoundUp)
elseif I == 4
out[I] = CUDA.add(x, y, RoundDown)
end
return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_add!(out_d, 1.0, 2^(-54))
out_h = Array(out_d)
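# Since 2^(-54) is only a quarter of eps(1.0), one would expect
# out_h ≈ [1.0, 1.0, nextfloat(1.0), 1.0]: only the upward-rounded
# addition moves away from 1.0.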

function test_sub!(out, x, y)
I = threadIdx().x
if I == 1
out[I] = CUDA.sub(x, y, RoundNearest)
elseif I == 2
out[I] = CUDA.sub(x, y, RoundToZero)
elseif I == 3
out[I] = CUDA.sub(x, y, RoundUp)
elseif I == 4
out[I] = CUDA.sub(x, y, RoundDown)
end
return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_sub!(out_d, 1.0, 2^(-53))
out_h = Array(out_d)
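# Here 1.0 - 2^(-53) is exactly representable (it is prevfloat(1.0)), so all
# four rounding modes should agree on the result.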

function test_mul!(out, x, y)
I = threadIdx().x
if I == 1
out[I] = CUDA.mul(x, y, RoundNearest)
elseif I == 2
out[I] = CUDA.mul(x, y, RoundToZero)
elseif I == 3
out[I] = CUDA.mul(x, y, RoundUp)
elseif I == 4
out[I] = CUDA.mul(x, y, RoundDown)
end
return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_mul!(out_d, 1.0 - 2^(-52), 1.0 + 2^(-52))
out_h = Array(out_d)
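# The exact product (1 - 2^(-52)) * (1 + 2^(-52)) = 1 - 2^(-104) lies just
# below 1.0, so one would expect out_h ≈ [1.0, prevfloat(1.0), 1.0, prevfloat(1.0)].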
138 changes: 121 additions & 17 deletions src/device/intrinsics/wmma.jl
@@ -15,7 +15,8 @@ const map_ptx_to_jl_array = Dict(
"s8" => Int8,
"s32" => Int32,
"f16" => Float16,
"f32" => Float32
"f32" => Float32,
"f64" => Float64
)

# Maps PTX types to Julia fragment types
@@ -24,10 +25,13 @@ const map_ptx_to_jl_frag = Dict(
"s8" => UInt32,
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
"f32" => Float32
"f32" => Float32,
"f64" => Float64
)

# Maps matrix & PTX types to fragment sizes
# Maps matrix & PTX types to fragment sizes; information retrieved from
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma

const map_frag_sizes = Dict(
# A
"a.u8.m16n16k16" => 2,
@@ -41,6 +45,9 @@ const map_frag_sizes = Dict(
"a.f16.m16n16k16" => 8,
"a.f16.m8n32k16" => 8,
"a.f16.m32n8k16" => 8,

"a.f64.m8n8k4" => 1,

# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
@@ -53,6 +60,9 @@
"b.f16.m16n16k16" => 8,
"b.f16.m8n32k16" => 8,
"b.f16.m32n8k16" => 8,

"b.f64.m8n8k4" => 1,

# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
@@ -65,6 +75,12 @@
"c.f32.m16n16k16" => 8,
"c.f32.m8n32k16" => 8,
"c.f32.m32n8k16" => 8,

"c.f64.m8n8k4" => 2, # there is a clash of documentation here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
# says `A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C.`
# while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1

# D
"d.s32.m16n16k16" => 8,
"d.s32.m8n32k16" => 8,
@@ -77,6 +93,8 @@
"d.f32.m16n16k16" => 8,
"d.f32.m8n32k16" => 8,
"d.f32.m32n8k16" => 8,

"d.f64.m8n8k4" => 2,
)

# Maps PTX AS to CUDA.AS
@@ -96,13 +114,19 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f32"], ["f16", "f32"]
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
ldst_int_ab_ops, ldst_int_cd_ops)
# Double
const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops,
ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops)

# the wmma_double_ops will be treated separately due to rounding
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]

################################################################################
# HELPER FUNCTIONS
@@ -256,20 +280,21 @@ export llvm_wmma_store
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_"))

# Name of the LLVM intrinsic
# e.g. llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)"
if LLVM.version() < v"17"
llvm_intr *= "i8"
end

# Determine types + size for this (matrix, elem_type) combination
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)

ccall_name = "$llvm_intr"
frag_types = ntuple(i -> frag_ty, sz)
frag_vars = ntuple(i -> :(data[$i]), sz)

ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval export $func_name
@eval @doc (@doc llvm_wmma_store) $func_name
@@ -283,6 +308,7 @@ end
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)
For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}`
For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`
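For example, the double-precision round-to-nearest variant `llvm_wmma_mma_col_col_m8n8k4_f64_rn` wraps the intrinsic `llvm.nvvm.wmma.m8n8k4.mma.col.col.rn.f64`.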
@@ -351,11 +377,91 @@ for ops in all_wmma_ops,
else
struct_ty = Symbol("LLVMStruct$d_sz")
@eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
@eval $func_name(a, b, c, ::RoundingMode{:Nearest}) = $func_name(a, b, c)
@eval $func_name(a, b, c, ::RoundingMode{:ToZero}) = $func_name(a, b, c)
@eval $func_name(a, b, c, ::RoundingMode{:Up}) = $func_name(a, b, c)
@eval $func_name(a, b, c, ::RoundingMode{:Down}) = $func_name(a, b, c)
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_mma) $func_name
end

const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"]

for ops in [wmma_double_ops],
a_layout in ["col", "row"],
b_layout in ["col", "row"],
mnk in ops[1],
rnd in wmma_double_rounding

a_elem_type = "f64"
b_elem_type = "f64"
c_elem_type = "f64"
d_elem_type = "f64"

shape = get_hl_shape(mnk[1], mnk[2], mnk[3])

llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
if rnd == ""
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64"
end
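# For example, a_layout = "col", b_layout = "row" and rnd = "rp" yield
# "llvm.nvvm.wmma.m8n8k4.mma.col.row.rp.f64"; the empty rounding string maps
# to the default intrinsic without a rounding infix.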
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))
func_name_no_round = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_")) # variant name without the rounding suffix

# Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)

ccall_name = "$llvm_intr"

a_types = ntuple(i -> a_frag_ty, a_sz)
b_types = ntuple(i -> b_frag_ty, b_sz)
c_types = ntuple(i -> c_frag_ty, c_sz)

a_vars = ntuple(i -> :(a[$i]), a_sz)
b_vars = ntuple(i -> :(b[$i]), b_sz)
c_vars = ntuple(i -> :(c[$i]), c_sz)

if d_sz == 1
@eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
else
struct_ty = Symbol("LLVMStruct$d_sz")
@eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_mma) $func_name
end

# TODO: rewrite this as a macro

llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
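# Example (device code): inside a kernel, where each thread of the warp holds
# its fragment slices, the double-precision MMA with an explicit rounding mode
# can then be invoked as, e.g.,
#     d_frag = llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundUp)
# which dispatches to llvm_wmma_mma_col_col_m8n8k4_f64_rp.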


# elseif d_elem_type == "f64"
# llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64.f64.f64.f64"
# # Name of the Julia wrapper function
# func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))



################################################################################
# FLATTENING/UNFLATTENING LOGIC
################################################################################
@@ -481,19 +587,17 @@ Type that contains all information for WMMA operations that cannot be inferred from the argument's types.
WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix,
``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices.
`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`.
`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16`, `Float32` or `Float64`.
All WMMA operations take a `Config` as their final argument.
# Examples
```jldoctest
julia> config = WMMA.Config{16, 16, 16, Float32}
CUDA.WMMA.Config{16, 16, 16, Float32}
julia> config = WMMA.Config{16, 16, 16, Float64}
CUDA.WMMA.Config{16, 16, 16, Float64}
```
"""
struct ConfigRounding{M, N, K, d_type, rounding} end

Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest}
struct Config{M, N, K, d_type} end

# ---------
# Constants
@@ -673,7 +777,7 @@ mma
b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x)
c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)

x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
x = flatten($wrapper(a_unfl, b_unfl, c_unfl, rounding))
return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x)
end
end
