[AMDGPU] Make fp8 support checks chipset-specific, check shaped types (#20152)

With RDNA4, we now have cards that allow the OCP FP8 formats (f8E5M2 and
f8E4M3FN). This means that the previous check that caused any module
that used those formats to be rejected needs to be relaxed.

However, the backend uses the same intrinsics for the gfx942-only FNUZ
types and the OCP types. Therefore, the data type validity check needs
to be aware of the chipset being targeted in order to validate that the
dispatch can be compiled correctly. This patch implements this improved
set of checks.

In addition, it adds tests for these checks (sadly, we can't use
expected-error here because the pipeline strips debug info) and improves
these checks to also look inside vector<> and memref<>.

(This also means that there is now an IREE-side mechanism to reject
compiling FP8 code pre-gfx942.)
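
To illustrate the shaped-type handling mentioned above, here is a minimal standalone sketch. It is not part of this change; it assumes only upstream MLIR's builtin float types, VectorType, and getElementTypeOrSelf, and mirrors the shape of the containsAPred<> helper added below. The point is that the predicate peels vector<>/memref<>/tensor<> down to the element type before the isa<> check, so a vector<4xf8E4M3FN> operand is caught the same way a bare f8E4M3FN is.

// Standalone sketch (not part of this commit): how a containsAPred<>-style
// helper sees through shaped types. Assumes only upstream MLIR headers.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

template <typename... Floats>
static bool containsAPred(Type type) {
  // getElementTypeOrSelf unwraps vector<>, memref<>, and tensor<> types, so
  // the isa<> check below always runs on the scalar element type.
  type = getElementTypeOrSelf(type);
  return llvm::isa<Floats...>(type);
}

int main() {
  MLIRContext ctx;
  Type ocpScalar = Float8E4M3FNType::get(&ctx);
  Type ocpVector = VectorType::get({4}, ocpScalar);
  Type plainF32 = Float32Type::get(&ctx);
  auto pred = containsAPred<Float8E5M2Type, Float8E4M3FNType>;
  llvm::outs() << pred(ocpScalar) << " " << pred(ocpVector) << " "
               << pred(plainF32) << "\n"; // expected: 1 1 0
  return 0;
}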
krzysz00 authored Mar 5, 2025
1 parent 42b7fc0 commit ec128bf
Showing 5 changed files with 151 additions and 14 deletions.
53 changes: 39 additions & 14 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -75,18 +75,38 @@ static void populateConvertGPUToAMDGPUPatterns(RewritePatternSet &patterns) {

} // namespace

template <typename... Floats>
static bool containsAPred(Type type) {
  type = getElementTypeOrSelf(type);
  return llvm::isa<Floats...>(type);
}

// Function to check valid data types on the ROCm backend.
static LogicalResult validateDataTypes(Operation *op) {
  auto operandTypes = llvm::to_vector(op->getOperandTypes());
  auto resultTypes = llvm::to_vector(op->getResultTypes());
  if (llvm::any_of(llvm::concat<Type>(operandTypes, resultTypes),
                   llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>)) {
    op->emitOpError()
        << "F8E5M2 and F8E4M3FN types are not supported on "
           "the ROCm backend; try F8E5M2FNUZ or F8E4M3FNUZ instead.";
    return failure();
// Note to readers: different chips take different FP8 formats but re-use the
// same instruction and intrinsic names, so we must filter out the "wrong" FP8
// here.
static LogicalResult validateDataTypes(Operation *op,
                                       const amdgpu::Chipset &chipset) {
  constexpr amdgpu::Chipset kGfx942 = amdgpu::Chipset(9, 4, 2);
  if (!amdgpu::hasOcpFp8(chipset)) {
    auto pred = containsAPred<Float8E5M2Type, Float8E4M3FNType>;
    if (llvm::any_of(op->getOperandTypes(), pred) ||
        llvm::any_of(op->getResultTypes(), pred)) {
      return op->emitOpError("F8E5M2 and F8E4M3FN types are not supported on "
                             "gfx942 (MI-300) or older chipsets; try "
                             "F8E5M2FNUZ or F8E4M3FNUZ instead.");
    }
  }

  if (chipset != kGfx942) {
    auto pred = containsAPred<Float8E5M2FNUZType, Float8E4M3FNUZType>;
    if (llvm::any_of(op->getOperandTypes(), pred) ||
        llvm::any_of(op->getResultTypes(), pred)) {
      return op->emitOpError(
          "F8E5M2FNUZ and F8E4M3FNUZ types are not supported on non-gfx942 "
          "(MI-300) chipsets; try F8E5M2 or F8E4M3FN instead.");
    }
  }
  return success();
}

@@ -108,11 +128,6 @@ struct ConvertToROCDLPass final
  void runOnOperation() override {
    ModuleOp m = getOperation();

    m.walk([&](Operation *op) {
      if (failed(validateDataTypes(op)))
        return signalPassFailure();
    });

    if (clROCMIndexingBits != 32 && clROCMIndexingBits != 64) {
      m.emitOpError() << "unsupported: ROCm index bit widths must either be "
                         "64 or 32, got "
@@ -152,6 +167,16 @@ struct ConvertToROCDLPass final
m.emitOpError() << "Invalid chipset name: " << chipset;
return signalPassFailure();
}
WalkResult allTypesValid = m.walk([&](Operation *op) {
if (failed(validateDataTypes(op, *maybeChipset))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (allTypesValid.wasInterrupted()) {
return signalPassFailure();
}

arith::populateArithToAMDGPUConversionPatterns(
patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
/*allowPackedF16Rtz=*/false, /*chipset=*/*maybeChipset);
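For readers cross-checking the chipset gate introduced in validateDataTypes above against the RUN lines in the new tests, here is a minimal standalone sketch. It is not part of this diff; it assumes MLIR's amdgpu::Chipset utilities (Chipset::parse and hasOcpFp8, which I believe live in mlir/Dialect/AMDGPU/Utils/Chipset.h) and simply prints which FP8 family each tested target accepts.

// Standalone sketch (not part of this commit): which FP8 family each test
// target accepts under the new validateDataTypes() gating.
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  const amdgpu::Chipset kGfx942(9, 4, 2);
  for (const char *name : {"gfx908", "gfx942", "gfx1201"}) {
    FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse(name);
    if (failed(chipset)) {
      llvm::errs() << "invalid chipset name: " << name << "\n";
      return 1;
    }
    // OCP f8E4M3FN/f8E5M2 pass only where hasOcpFp8() holds (gfx1201 here);
    // the FNUZ variants pass only on gfx942, matching the lit RUN lines.
    llvm::outs() << name << ": OCP fp8 "
                 << (amdgpu::hasOcpFp8(*chipset) ? "accepted" : "rejected")
                 << ", FNUZ fp8 "
                 << (*chipset == kGfx942 ? "accepted" : "rejected") << "\n";
  }
  return 0;
}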
BUILD.bazel
@@ -26,6 +26,8 @@ iree_lit_test_suite(
"config_vector_distribute_reduction_gfx942.mlir",
"config_user_vector_distribute.mlir",
"lowering_scalar_dispatch.mlir",
"pipeline_elementwise_f8fnuz.mlir",
"pipeline_elementwise_f8ocp.mlir",
"pipeline_igemm_tile_and_fuse.mlir",
"pipeline_tile_and_fuse.mlir",
"pipeline_vector_distribute_gfx942.mlir",
@@ -39,5 +41,6 @@ iree_lit_test_suite(
tools = [
"//tools:iree-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
CMakeLists.txt
@@ -22,6 +22,8 @@ iree_lit_test_suite(
"config_vector_distribute_gfx942.mlir"
"config_vector_distribute_reduction_gfx942.mlir"
"lowering_scalar_dispatch.mlir"
"pipeline_elementwise_f8fnuz.mlir"
"pipeline_elementwise_f8ocp.mlir"
"pipeline_igemm_tile_and_fuse.mlir"
"pipeline_tile_and_fuse.mlir"
"pipeline_vector_distribute_gfx1100.mlir"
@@ -31,6 +33,7 @@ iree_lit_test_suite(
TOOLS
FileCheck
iree-opt
not
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
pipeline_elementwise_f8fnuz.mlir (new file)
@@ -0,0 +1,53 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" %s | FileCheck %s --check-prefix=CDNA3
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx1201 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS

#map = affine_map<(d0) -> (d0)>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable @ext_fp8_dispatch {
hal.executable.variant @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export @ext_fp8_dispatch layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @ext_fp8_dispatch() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FNUZ>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2FNUZ>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FNUZ>> -> tensor<4096xf8E4M3FNUZ>
%4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2FNUZ>> -> tensor<4096xf8E5M2FNUZ>
%5 = tensor.empty() : tensor<4096xf32>
%6 = linalg.generic {indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]}
ins(%3, %4 : tensor<4096xf8E4M3FNUZ>, tensor<4096xf8E5M2FNUZ>)
outs(%5 : tensor<4096xf32>) {
^bb0(%in0: f8E4M3FNUZ, %in1: f8E5M2FNUZ, %out: f32):
%7 = arith.extf %in0 : f8E4M3FNUZ to f32
%8 = arith.extf %in1 : f8E5M2FNUZ to f32
%9 = arith.addf %7, %8 : f32
linalg.yield %9 : f32
} -> tensor<4096xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
return
}
}
}
}

// ERRORS: F8E5M2FNUZ and F8E4M3FNUZ types are not supported on non-gfx942 (MI-300) chipsets; try F8E5M2 or F8E4M3FN instead.

// CDNA3-LABEL: hal.executable public @ext_fp8_dispatch
// CDNA3: hal.executable.variant public @rocm
// CDNA3-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32
// CDNA3-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32
// CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32>
// CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1>

// -----
pipeline_elementwise_f8ocp.mlir (new file)
@@ -0,0 +1,53 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1201 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" %s | FileCheck %s --check-prefix=RDNA4
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS

#map = affine_map<(d0) -> (d0)>
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable @ext_fp8_dispatch {
hal.executable.variant @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export @ext_fp8_dispatch layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @ext_fp8_dispatch() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FN>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FN>> -> tensor<4096xf8E4M3FN>
%4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2>> -> tensor<4096xf8E5M2>
%5 = tensor.empty() : tensor<4096xf32>
%6 = linalg.generic {indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]}
ins(%3, %4 : tensor<4096xf8E4M3FN>, tensor<4096xf8E5M2>)
outs(%5 : tensor<4096xf32>) {
^bb0(%in0: f8E4M3FN, %in1: f8E5M2, %out: f32):
%7 = arith.extf %in0 : f8E4M3FN to f32
%8 = arith.extf %in1 : f8E5M2 to f32
%9 = arith.addf %7, %8 : f32
linalg.yield %9 : f32
} -> tensor<4096xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
return
}
}
}
}

// ERRORS: F8E5M2 and F8E4M3FN types are not supported on gfx942 (MI-300) or older chipsets; try F8E5M2FNUZ or F8E4M3FNUZ instead.

// RDNA4-LABEL: hal.executable public @ext_fp8_dispatch
// RDNA4: hal.executable.variant public @rocm
// RDNA4-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32
// RDNA4-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32
// RDNA4: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32>
// RDNA4: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1>

// -----
