diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl
index c12fe526ca..b95d9033d8 100644
--- a/src/device/intrinsics/wmma.jl
+++ b/src/device/intrinsics/wmma.jl
@@ -1,7 +1,7 @@
 export WMMA
 module WMMA
 
-using ..CUDA: AS
+using ..CUDA: AS, BFloat16
 using Core: LLVMPtr
 
 ################################################################################
@@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict(
     "s8" => Int8,
     "s32" => Int32,
     "f16" => Float16,
+    "bf16" => BFloat16,
     "f32" => Float32
 )
 
@@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict(
     "s8" => UInt32,
     "s32" => Int32,
     "f16" => NTuple{2, VecElement{Float16}},
+    "bf16" => Float32,
     "f32" => Float32
 )
 
@@ -40,6 +42,10 @@ const map_frag_sizes = Dict(
     "a.f16.m16n16k16" => 8,
     "a.f16.m8n32k16" => 8,
     "a.f16.m32n8k16" => 8,
+
+    "a.bf16.m16n16k16" => 4,
+    "a.bf16.m8n32k16" => 2,
+    "a.bf16.m32n8k16" => 8,
     # B
     "b.u8.m16n16k16" => 2,
     "b.u8.m8n32k16" => 4,
@@ -52,7 +58,11 @@ const map_frag_sizes = Dict(
     "b.f16.m16n16k16" => 8,
     "b.f16.m8n32k16" => 8,
     "b.f16.m32n8k16" => 8,
-    # C
+
+    "b.bf16.m16n16k16" => 4,
+    "b.bf16.m8n32k16" => 8,
+    "b.bf16.m32n8k16" => 2,
+    # C
     "c.s32.m16n16k16" => 8,
     "c.s32.m8n32k16" => 8,
     "c.s32.m32n8k16" => 8,
@@ -95,10 +105,14 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
 const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
 const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
 const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
+# BFloat16
+const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
+const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
 
 const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
-                          ldst_int_ab_ops, ldst_int_cd_ops)
-const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
+                          ldst_int_ab_ops, ldst_int_cd_ops,
+                          ldst_bf16_ab_ops)
+const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
 
 # Valid WMMA operation shapes
 const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
@@ -309,7 +323,7 @@ for ops in all_wmma_ops,
 
     # Name of the LLVM intrinsic
     # If integer/sub-byte/bit A/B types, name is determined by A/B types
-    if d_elem_type == "s32"
+    if d_elem_type == "s32" || a_elem_type == "bf16"
         llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
         # Name of the Julia wrapper function
         func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
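
For reference, a minimal sketch of how the new BF16 path could be exercised from a kernel, in the style of the existing f16 low-level WMMA wrappers. The `llvm_wmma_*_bf16` names below are assumptions derived from the naming scheme this file generates (the `join(["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type], "_")` pattern visible in the last hunk, and the analogous load/store pattern for the `ldst_bf16_ab_ops` entries); they are not taken verbatim from this diff.

```julia
# Sketch only: assumes the generated BF16 wrapper names mirror the existing f16 ones.
using CUDA
using CUDA: BFloat16   # same binding this patch imports into the WMMA module

function wmma_bf16_kernel(a_dev, b_dev, c_dev, d_dev)
    # Load the A/B tiles as BF16 fragments and the accumulator C as Float32.
    a_frag = WMMA.llvm_wmma_load_a_col_m16n16k16_global_stride_bf16(pointer(a_dev), 16)
    b_frag = WMMA.llvm_wmma_load_b_col_m16n16k16_global_stride_bf16(pointer(b_dev), 16)
    c_frag = WMMA.llvm_wmma_load_c_col_m16n16k16_global_stride_f32(pointer(c_dev), 16)

    # D = A * B + C; BF16 inputs accumulate into Float32, per wmma_bf16_ops above.
    d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k16_bf16(a_frag, b_frag, c_frag)

    WMMA.llvm_wmma_store_d_col_m16n16k16_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end

a = CuArray(BFloat16.(rand(Float32, 16, 16)))
b = CuArray(BFloat16.(rand(Float32, 16, 16)))
c = CUDA.rand(Float32, 16, 16)
d = CUDA.zeros(Float32, 16, 16)

# WMMA is a warp-level operation, so launch a full warp.
@cuda threads=32 wmma_bf16_kernel(a, b, c, d)
```

Note that this only covers the generated low-level wrappers touched by the diff; the high-level `WMMA.Config`-based API would presumably need its own BFloat16 handling, which this patch does not change.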