diff --git a/docs/src/hacking/exposing_new_intrinsics.jl b/docs/src/hacking/exposing_new_intrinsics.jl
index 08440af1fe..662940af27 100644
--- a/docs/src/hacking/exposing_new_intrinsics.jl
+++ b/docs/src/hacking/exposing_new_intrinsics.jl
@@ -11,6 +11,8 @@
 # where `rn` is round to nearest, `rz` round to zero, `rm` round to minus infinity, `rp` round to plus infinity.
 # When building the intrinsic for the call, we need to change the type `.f64` with `.d` and `.f32` with `.f`
 # Therefore, to call the rounded towards infinity `fma` for `.f64` we need to call the intrinsic `llvm.nvvm.fma.rp.d`
+# Note that this is only possible if LLVM supports the intrinsic; the intrinsics exposed by LLVM
+# can be found by searching the [LLVM repository](https://github.com/llvm/llvm-project). In other cases you'd need `@asmcall` and inline PTX assembly.
 
 fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
 fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)
@@ -21,6 +23,7 @@
 CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64})
 # It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we add this function now
 # to src/device/intrins/math.jl
+using CUDA
 function test_fma!(out, x, y)
     I = threadIdx().x
     z = (2.0) ^ (-(I+53))
@@ -44,58 +47,3 @@ out_d = CuArray(zeros(4))
 @cuda threads = 4 test_fma!(out_d, -1.0, 1.0)
 out_h = Array(out_d)
-# The binary operations as add, sub, mul, div have been implemented through a macro
-
-function test_add!(out, x, y)
-    I = threadIdx().x
-    if I == 1
-        out[I] = CUDA.add(x, y, RoundNearest)
-    elseif I == 2
-        out[I] = CUDA.add(x, y, RoundToZero)
-    elseif I == 3
-        out[I] = CUDA.add(x, y, RoundUp)
-    elseif I == 4
-        out[I] = CUDA.add(x, y, RoundDown)
-    end
-    return
-end
-
-out_d = CuArray(zeros(4))
-@cuda threads = 4 test_add!(out_d, 1.0, 2^(-54))
-out_h = Array(out_d)
-
-function test_sub!(out, x, y)
-    I = threadIdx().x
-    if I == 1
-        out[I] = CUDA.sub(x, y, RoundNearest)
-    elseif I == 2
-        out[I] = CUDA.sub(x, y, RoundToZero)
-    elseif I == 3
-        out[I] = CUDA.sub(x, y, RoundUp)
-    elseif I == 4
-        out[I] = CUDA.sub(x, y, RoundDown)
-    end
-    return
-end
-
-out_d = CuArray(zeros(4))
-@cuda threads = 4 test_sub!(out_d, 1.0, 2^(-53))
-out_h = Array(out_d)
-
-function test_mul!(out, x, y)
-    I = threadIdx().x
-    if I == 1
-        out[I] = CUDA.mul(x, y, RoundNearest)
-    elseif I == 2
-        out[I] = CUDA.mul(x, y, RoundToZero)
-    elseif I == 3
-        out[I] = CUDA.mul(x, y, RoundUp)
-    elseif I == 4
-        out[I] = CUDA.mul(x, y, RoundDown)
-    end
-    return
-end
-
-out_d = CuArray(zeros(4))
-@cuda threads = 4 test_mul!(out_d, 1.0 - 2^(-52), 1.0 + 2^(-52))
-out_h = Array(out_d)
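+
+# The same pattern generalizes to the other rounded intrinsics. As a minimal sketch
+# (illustrative, not part of this PR; it assumes the LLVM build exposes
+# `llvm.nvvm.add.rn.d`, which can be verified by searching the LLVM repository as
+# described above), a round-to-nearest double-precision add would be wrapped as:
+add_rn(x::Float64, y::Float64) = ccall("llvm.nvvm.add.rn.d", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
+add(x::Float64, y::Float64, ::RoundingMode{:Nearest}) = add_rn(x, y)
+# Inspecting the generated PTX should again show the rounded instruction, here `add.rn.f64`:
+CUDA.code_ptx(add_rn, Tuple{Float64, Float64})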
diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl
index c02b0370bf..34755138bd 100644
--- a/src/device/intrinsics/wmma.jl
+++ b/src/device/intrinsics/wmma.jl
@@ -15,7 +15,8 @@ const map_ptx_to_jl_array = Dict(
     "s8" => Int8,
     "s32" => Int32,
     "f16" => Float16,
-    "f32" => Float32
+    "f32" => Float32,
+    "f64" => Float64
 )
 
 # Maps PTX types to Julia fragment types
@@ -24,10 +25,13 @@ const map_ptx_to_jl_frag = Dict(
     "s8" => UInt32,
     "s32" => Int32,
    "f16" => NTuple{2, VecElement{Float16}},
-    "f32" => Float32
+    "f32" => Float32,
+    "f64" => Float64
 )
 
-# Maps matrix & PTX types to fragment sizes
+# Maps matrix & PTX types to fragment sizes, information retrieved from
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma
+
 const map_frag_sizes = Dict(
     # A
     "a.u8.m16n16k16" => 2,
@@ -41,6 +45,9 @@ const map_frag_sizes = Dict(
     "a.f16.m16n16k16" => 8,
     "a.f16.m8n32k16" => 8,
     "a.f16.m32n8k16" => 8,
+
+    "a.f64.m8n8k4" => 1,
+
     # B
     "b.u8.m16n16k16" => 2,
     "b.u8.m8n32k16" => 4,
@@ -53,6 +60,9 @@ const map_frag_sizes = Dict(
     "b.f16.m16n16k16" => 8,
     "b.f16.m8n32k16" => 8,
     "b.f16.m32n8k16" => 8,
+
+    "b.f64.m8n8k4" => 1,
+
     # C
     "c.s32.m16n16k16" => 8,
     "c.s32.m8n32k16" => 8,
@@ -65,6 +75,12 @@ const map_frag_sizes = Dict(
     "c.f32.m16n16k16" => 8,
     "c.f32.m8n32k16" => 8,
     "c.f32.m32n8k16" => 8,
+
+    "c.f64.m8n8k4" => 2, # there is a clash in the documentation here:
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
+    # says `A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C.`
+    # while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1
+
     # D
     "d.s32.m16n16k16" => 8,
     "d.s32.m8n32k16" => 8,
@@ -77,6 +93,8 @@ const map_frag_sizes = Dict(
     "d.f32.m16n16k16" => 8,
     "d.f32.m8n32k16" => 8,
     "d.f32.m32n8k16" => 8,
+
+    "d.f64.m8n8k4" => 2,
 )
 
 # Maps PTX AS to CUDA.AS
@@ -96,13 +114,19 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
 const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
 const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
 const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
-
-const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
-                          ldst_int_ab_ops, ldst_int_cd_ops)
+# Double
+const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
+const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
+const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]
+
+const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops,
+                          ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops)
+
+# the wmma_double_ops will be treated separately due to rounding
 const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
 
 # Valid WMMA operation shapes
-const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
+const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]
 
 ################################################################################
 # HELPER FUNCTIONS
################################################################################
@@ -256,6 +280,7 @@ export llvm_wmma_store
     func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_"))
 
     # Name of the LLVM intrinsic
+    # e.g. llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
     llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)"
     if LLVM.version() < v"17"
         llvm_intr *= "i8"
@@ -263,13 +288,13 @@ export llvm_wmma_store
     # Determine types + size for this (matrix, elem_type) combination
     arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)
-    
+
     ccall_name = "$llvm_intr"
     frag_types = ntuple(i -> frag_ty, sz)
     frag_vars = ntuple(i -> :(data[$i]), sz)
-    
+
     ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})
-    
+
     @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
     @eval export $func_name
     @eval @doc (@doc llvm_wmma_store) $func_name
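+
+    # For the f64 entries added above this generates, for instance, a wrapper
+    # `llvm_wmma_store_d_col_m8n8k4_global_stride_f64(dst_addr, data, stride)`
+    # around `llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1`, storing the two
+    # Float64 registers of the D fragment (illustrative; the exact names follow
+    # the `join` and `llvm_intr` constructions above).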
@@ -283,6 +308,7 @@ end
     WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
     WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)
 
+For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}`
 For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
 For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`
@@ -351,11 +377,91 @@
     else
         struct_ty = Symbol("LLVMStruct$d_sz")
         @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
+        @eval $func_name(a, b, c, ::RoundingMode{:Nearest}) = $func_name(a, b, c)
+        @eval $func_name(a, b, c, ::RoundingMode{:ToZero}) = $func_name(a, b, c)
+        @eval $func_name(a, b, c, ::RoundingMode{:Up}) = $func_name(a, b, c)
+        @eval $func_name(a, b, c, ::RoundingMode{:Down}) = $func_name(a, b, c)
     end
     @eval export $func_name
     @eval @doc (@doc llvm_wmma_mma) $func_name
 end
 
+const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"]
+
+for ops in [wmma_double_ops],
+    a_layout in ["col", "row"],
+    b_layout in ["col", "row"],
+    mnk in ops[1],
+    rnd in wmma_double_rounding
+
+    a_elem_type = "f64"
+    b_elem_type = "f64"
+    c_elem_type = "f64"
+    d_elem_type = "f64"
+
+    shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
+
+    llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
+    if rnd == ""
+        llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64"
+    end
+    # Name of the Julia wrapper function; the empty rounding mode is filtered out,
+    # so the unsuffixed name wraps the default-rounding intrinsic
+    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))
+
+    # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
+    a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
+    b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
+    c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
+    d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)
+
+    ccall_name = "$llvm_intr"
+
+    a_types = ntuple(i -> a_frag_ty, a_sz)
+    b_types = ntuple(i -> b_frag_ty, b_sz)
+    c_types = ntuple(i -> c_frag_ty, c_sz)
+
+    a_vars = ntuple(i -> :(a[$i]), a_sz)
+    b_vars = ntuple(i -> :(b[$i]), b_sz)
+    c_vars = ntuple(i -> :(c[$i]), c_sz)
+
+    if d_sz == 1
+        @eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
+    else
+        struct_ty = Symbol("LLVMStruct$d_sz")
+        @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
+    end
+    @eval export $func_name
+    @eval @doc (@doc llvm_wmma_mma) $func_name
+end
+
+# TODO: rewrite this as a macro
+llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
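+
+# A sketch of the loop/@eval rewrite suggested by the TODO above (illustrative
+# only, kept as a comment so the explicit definitions here remain the active
+# code; it assumes the `_rn`/`_rz`/`_rp`/`_rm` variants generated above):
+#
+# for a_layout in ["col", "row"], b_layout in ["col", "row"]
+#     base = Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m8n8k4_f64")
+#     for (mode, suffix) in [(:Nearest, "rn"), (:ToZero, "rz"), (:Up, "rp"), (:Down, "rm")]
+#         impl = Symbol("$(base)_$(suffix)")
+#         @eval $base(a, b, c, ::RoundingMode{$(QuoteNode(mode))}) = $impl(a, b, c)
+#     end
+# end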
+llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
+llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
+llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
+
 ################################################################################
 # FLATTENING/UNFLATTENING LOGIC
 ################################################################################
@@ -481,19 +587,17 @@ Type that contains all information for WMMA operations that cannot be inferred f
 WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix, ``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices.
 
-`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`.
+`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16`, `Float32` or `Float64`.
 
 All WMMA operations take a `Config` as their final argument.
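+
+For example, the double-precision `m8n8k4` operations added here use `WMMA.Config{8, 8, 4, Float64}`.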
 
 # Examples
 ```jldoctest
-julia> config = WMMA.Config{16, 16, 16, Float32}
-CUDA.WMMA.Config{16, 16, 16, Float32}
+julia> config = WMMA.Config{16, 16, 16, Float64}
+CUDA.WMMA.Config{16, 16, 16, Float64}
 ```
 """
-struct ConfigRounding{M, N, K, d_type, rounding} end
-
-Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest}
+struct Config{M, N, K, d_type} end
 
 # ---------
 # Constants
@@ -673,7 +777,7 @@ mma
         b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x)
         c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)
 
-        x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
+        x = flatten($wrapper(a_unfl, b_unfl, c_unfl, rounding))
         return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x)
     end
 end
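+
+# Sketch of the double-precision path end to end (illustrative only; it assumes
+# a device of compute capability 8.0 or higher, and that the high-level `mma`
+# accepts the `rounding` argument threaded through above, whose full signature
+# is not shown in this diff):
+#
+#     function kernel(a_dev, b_dev, c_dev, d_dev)
+#         conf = WMMA.Config{8, 8, 4, Float64}
+#         a_frag = WMMA.load_a(pointer(a_dev), 8, WMMA.ColMajor, conf)
+#         b_frag = WMMA.load_b(pointer(b_dev), 4, WMMA.ColMajor, conf)
+#         c_frag = WMMA.load_c(pointer(c_dev), 8, WMMA.ColMajor, conf)
+#         d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)
+#         WMMA.store_d(pointer(d_dev), d_frag, 8, WMMA.ColMajor, conf)
+#         return
+#     end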