Skip to content

Commit

Permalink
[LLVMGPU][ROCm] Plumb through i8, i8 -> i32 MFMA intrinsics (iree-org…
Browse files Browse the repository at this point in the history
…#17764)

Add tests to make sure these are generated in the vector distribution
pipeline. Add e2e correctness tests.

I also tested this manually on random inputs against golden outputs from
numpy.

This contains one cherry-pick for llvm-project.

---------

Co-authored-by: Stanley Winata <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2024
1 parent 4294a5b commit dcba7c5
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 17 deletions.
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>],
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>]
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
Expand Down
92 changes: 86 additions & 6 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,23 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
MMAIntrinsic type) {
Type f16 = Float16Type::get(context);
Type f32 = Float32Type::get(context);

Type i8 = IntegerType::get(context, 8);
Type i32 = IntegerType::get(context, 32);

switch (type) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
Expand Down Expand Up @@ -280,13 +290,47 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 8});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {32});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 8});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout =
PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4});
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
// #layout_c = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16});
Expand Down Expand Up @@ -372,6 +416,18 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
auto aType = VectorType::get({16}, getAType());
Expand All @@ -396,6 +452,8 @@ int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 1;
Expand All @@ -408,7 +466,9 @@ int64_t MMAAttr::getBlockSize() const {
int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return 64;
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
Expand All @@ -430,6 +490,14 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 8}};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 8}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 16},
Expand All @@ -449,6 +517,14 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{8, 1}};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{8, 1}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
Expand All @@ -460,11 +536,13 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
Expand Down Expand Up @@ -494,7 +572,9 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
}
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,18 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 2>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 3>;
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_I8_16x16x32_I32,
MFMA_I8_32x32x16_I32,
WMMA_F16_16x16x16_F32,
WMMA_F16_16x16x16_F16
]>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
static const WgpDetails cdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ iree_cc_library(
iree::compiler::Codegen::Utils
iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::Transforms
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::LinalgExt::Transforms
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -444,6 +446,15 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
Type rhsElemType = getElementTypeOrSelf(rhs);
Type initElemType = getElementTypeOrSelf(init);

if (auto lhsOp = lhs.getDefiningOp<linalg::GenericOp>()) {
if (IREE::Flow::isDequantizationLikeOp(lhsOp))
lhsElemType = getElementTypeOrSelf(lhsOp.getDpsInputs()[0]);
}
if (auto rhsOp = rhs.getDefiningOp<linalg::GenericOp>()) {
if (IREE::Flow::isDequantizationLikeOp(rhsOp))
rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]);
}

GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};

Expand Down
Loading

0 comments on commit dcba7c5

Please sign in to comment.