Revert wmma.jl
orkolorko committed Dec 20, 2024
1 parent 30105ee commit 3c2d721
Showing 1 changed file with 16 additions and 118 deletions.
134 changes: 16 additions & 118 deletions src/device/intrinsics/wmma.jl
@@ -15,8 +15,7 @@ const map_ptx_to_jl_array = Dict(
"s8" => Int8,
"s32" => Int32,
"f16" => Float16,
"f32" => Float32,
"f64" => Float64
"f32" => Float32
)

# Maps PTX types to Julia fragment types
@@ -25,8 +24,7 @@ const map_ptx_to_jl_frag = Dict(
"s8" => UInt32,
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
"f32" => Float32,
"f64" => Float64
"f32" => Float32
)

# Maps matrix & PTX types to fragment sizes, information retrieved from
@@ -45,9 +43,6 @@ const map_frag_sizes = Dict(
"a.f16.m16n16k16" => 8,
"a.f16.m8n32k16" => 8,
"a.f16.m32n8k16" => 8,

"a.f64.m8n8k4" => 1,

# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
@@ -60,9 +55,6 @@ const map_frag_sizes = Dict(
"b.f16.m16n16k16" => 8,
"b.f16.m8n32k16" => 8,
"b.f16.m32n8k16" => 8,

"b.f64.m8n8k4" => 1,

# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
@@ -75,12 +67,6 @@ const map_frag_sizes = Dict(
"c.f32.m16n16k16" => 8,
"c.f32.m8n32k16" => 8,
"c.f32.m32n8k16" => 8,

"c.f64.m8n8k4" => 2, # there is a clash of documentation here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
# says `A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C.`
# while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1

# D
"d.s32.m16n16k16" => 8,
"d.s32.m8n32k16" => 8,
@@ -93,8 +79,6 @@ const map_frag_sizes = Dict(
"d.f32.m16n16k16" => 8,
"d.f32.m8n32k16" => 8,
"d.f32.m32n8k16" => 8,

"d.f64.m8n8k4" => 2,
)

# Maps PTX AS to CUDA.AS
@@ -114,19 +98,13 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
# Double
const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops,
ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops)

# the wmma_double_ops will be treated separatedly due to rounding

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
ldst_int_ab_ops, ldst_int_cd_ops)
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]

################################################################################
# HELPER FUNCTIONS
@@ -280,21 +258,20 @@ export llvm_wmma_store
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_"))

# Name of the LLVM intrinsic
#llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)"
if LLVM.version() < v"17"
llvm_intr *= "i8"
end

# Determine types + size for this (matrix, elem_type) combination
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)

ccall_name = "$llvm_intr"
frag_types = ntuple(i -> frag_ty, sz)
frag_vars = ntuple(i -> :(data[$i]), sz)

ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval export $func_name
@eval @doc (@doc llvm_wmma_store) $func_name
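# Hedged illustration (not part of this diff): for mat = "d", layout = "col",
# shape = "m16n16k16", addr_space = "global" (address space 1) and elem_type = "f32"
# (whose fragment holds 8 Float32 elements per the tables above), the `@eval` above
# would define, roughly:
#
#     llvm_wmma_store_d_col_m16n16k16_global_stride_f32(dst_addr, data, stride) =
#         ccall("llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1", llvmcall, Nothing,
#               (LLVMPtr{Float32, 1}, Float32, Float32, Float32, Float32,
#                Float32, Float32, Float32, Float32, Int32),
#               dst_addr, data[1], data[2], data[3], data[4],
#               data[5], data[6], data[7], data[8], stride)
#
# (with an "i8" suffix appended to the intrinsic name on LLVM < 17, as handled above).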
@@ -308,7 +285,6 @@ end
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)
For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}`
For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`
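For example (illustrative instantiations of the templates above, not part of this diff): a row-major-A, col-major-B `m16n16k16` FP16×FP16→FP32 operation is exposed as `WMMA.llvm_wmma_mma_row_col_m16n16k16_f32_f32(a, b, c)`, wrapping `@llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f32.f32`; the corresponding `s8` integer variant is `WMMA.llvm_wmma_mma_row_col_m16n16k16_s8(a, b, c)`, wrapping `@llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.s8`.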
@@ -372,59 +348,6 @@ for ops in all_wmma_ops,
b_vars = ntuple(i -> :(b[$i]), b_sz)
c_vars = ntuple(i -> :(c[$i]), c_sz)

if d_sz == 1
@eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
else
struct_ty = Symbol("LLVMStruct$d_sz")
@eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
@eval $func_name(a, b, c, ::RoundingMode{:Nearest}) = $func_name(a, b, c)
@eval $func_name(a, b, c, ::RoundingMode{:ToZero}) = $func_name(a, b, c)
@eval $func_name(a, b, c, ::RoundingMode{:Up}) = $func_name(a, b, c)
@eval $func_name(a, b, c, ::RoundingMode{:Down}) = $func_name(a, b, c)
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_mma) $func_name
end

const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"]

for ops in [wmma_double_ops],
a_layout in ["col", "row"],
b_layout in ["col", "row"],
mnk in ops[1],
rnd in wmma_double_rounding

a_elem_type = "f64"
b_elem_type = "f64"
c_elem_type = "f64"
d_elem_type = "f64"

shape = get_hl_shape(mnk[1], mnk[2], mnk[3])

llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
if rnd == ""
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64"
end
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))
func_name_no_round = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))

# Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)

ccall_name = "$llvm_intr"

a_types = ntuple(i -> a_frag_ty, a_sz)
b_types = ntuple(i -> b_frag_ty, b_sz)
c_types = ntuple(i -> c_frag_ty, c_sz)

a_vars = ntuple(i -> :(a[$i]), a_sz)
b_vars = ntuple(i -> :(b[$i]), b_sz)
c_vars = ntuple(i -> :(c[$i]), c_sz)

if d_sz == 1
@eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
else
@@ -435,33 +358,6 @@ for ops in [wmma_double_ops],
@eval @doc (@doc llvm_wmma_mma) $func_name
end

# TODO, rewrite this as a macro

llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)


# elseif d_elem_type == "f64"
# llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64.f64.f64.f64"
# # Name of the Julia wrapper function
# func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))



################################################################################
# FLATTENING/UNFLATTENING LOGIC
################################################################################
@@ -587,17 +483,19 @@ Type that contains all information for WMMA operations that cannot be inferred f
WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix,
``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices.
`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16`, `Float32` or `Float64`.
`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`.
All WMMA operations take a `Config` as their final argument.
# Examples
```jldoctest
config = WMMA.Config{16, 16, 16, Float64}
CUDA.WMMA.Config{16, 16, 16, Float64}
julia> config = WMMA.Config{16, 16, 16, Float32}
CUDA.WMMA.Config{16, 16, 16, Float32}
```
"""
struct Config{M, N, K, d_type} end
struct ConfigRounding{M, N, K, d_type, rounding} end

Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest}
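# Hedged usage sketch (illustrative, not part of this file): a `Config` is passed as the
# final argument to every high-level WMMA call, e.g. inside a kernel operating on
# 16x16 Float16 inputs with a Float32 accumulator:
#
#     conf = WMMA.Config{16, 16, 16, Float32}
#     a_frag = WMMA.load_a(pointer(a_dev), 16, WMMA.ColMajor, conf)
#     b_frag = WMMA.load_b(pointer(b_dev), 16, WMMA.ColMajor, conf)
#     c_frag = WMMA.load_c(pointer(c_dev), 16, WMMA.ColMajor, conf)
#     d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)
#     WMMA.store_d(pointer(d_dev), d_frag, 16, WMMA.ColMajor, conf)
#
# where `a_dev`, `b_dev`, `c_dev`, `d_dev` are assumed to be device arrays of the
# matching element types.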

# ---------
# Constants
@@ -777,7 +675,7 @@ mma
b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x)
c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)

x = flatten($wrapper(a_unfl, b_unfl, c_unfl, rounding))
x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x)
end
end
