From ec128bf6b67e0b1279d6856bad5882c257c9fdbb Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak
Date: Tue, 4 Mar 2025 19:08:17 -0600
Subject: [PATCH] [AMDGPU] Make fp8 support checks chipset-specific, check
 shaped types (#20152)

With RDNA4, we now have cards that allow the OCP FP8 formats (f8E5M2
and f8E4M3FN). This means that the previous check, which rejected any
module using those formats, needs to be relaxed. However, the backend
uses the same intrinsics for the gfx942-only FNUZ types and the OCP
types. Therefore, the data type validity check needs to be aware of the
chipset being targeted in order to validate that the dispatch can be
compiled correctly.

This patch implements that improved set of checks, extends them to also
look inside vector<> and memref<> element types, and adds tests for
them (sadly, we can't use expected-error here because the pipeline
strips debug info).

(This also means that there is now an IREE-side mechanism to reject
compiling FP8 code pre-gfx942.)
---
 .../Codegen/LLVMGPU/ConvertToROCDL.cpp        | 53 ++++++++++++++-----
 .../Codegen/LLVMGPU/test/ROCDL/BUILD.bazel    |  3 ++
 .../Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt |  3 ++
 .../ROCDL/pipeline_elementwise_f8fnuz.mlir    | 53 +++++++++++++++++++
 .../ROCDL/pipeline_elementwise_f8ocp.mlir     | 53 +++++++++++++++++++
 5 files changed, 151 insertions(+), 14 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8fnuz.mlir
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8ocp.mlir

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
index 60355b8f0db0..daef47e2ee96 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -75,18 +75,38 @@ static void populateConvertGPUToAMDGPUPatterns(RewritePatternSet &patterns) {
 
 } // namespace
 
+template <typename... Types>
+static bool containsAPred(Type type) {
+  type = getElementTypeOrSelf(type);
+  return llvm::isa<Types...>(type);
+}
+
 // Function to check valid data types on the ROCm backend.
-static LogicalResult validateDataTypes(Operation *op) {
-  auto operandTypes = llvm::to_vector(op->getOperandTypes());
-  auto resultTypes = llvm::to_vector(op->getResultTypes());
-  if (llvm::any_of(llvm::concat<Type>(operandTypes, resultTypes),
-                   llvm::IsaPred<Float8E5M2Type, Float8E4M3FNType>)) {
-    op->emitOpError()
-        << "F8E5M2 and F8E4M3FN types are not supported on "
-           "the ROCm backend; try F8E5M2FNUZ or F8E4M3FNUZ instead.";
-    return failure();
+// Note to readers: different chips take different FP8 formats but re-use the
+// same instruction and intrinsic names, so we must filter out the "wrong" FP8
+// here.
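+// Concretely, given the checks below: pre-gfx942 targets such as gfx908
+// reject both FP8 families, gfx942 accepts only the FNUZ formats, and
+// OCP-capable chipsets such as gfx1201 (RDNA4) accept only the OCP formats.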
+static LogicalResult validateDataTypes(Operation *op,
+                                       const amdgpu::Chipset &chipset) {
+  constexpr amdgpu::Chipset kGfx942 = amdgpu::Chipset(9, 4, 2);
+  if (!amdgpu::hasOcpFp8(chipset)) {
+    auto pred = containsAPred<Float8E5M2Type, Float8E4M3FNType>;
+    if (llvm::any_of(op->getOperandTypes(), pred) ||
+        llvm::any_of(op->getResultTypes(), pred)) {
+      return op->emitOpError("F8E5M2 and F8E4M3FN types are not supported on "
+                             "gfx942 (MI-300) or older chipsets; try "
+                             "F8E5M2FNUZ or F8E4M3FNUZ instead.");
+    }
   }
+  if (chipset != kGfx942) {
+    auto pred = containsAPred<Float8E5M2FNUZType, Float8E4M3FNUZType>;
+    if (llvm::any_of(op->getOperandTypes(), pred) ||
+        llvm::any_of(op->getResultTypes(), pred)) {
+      return op->emitOpError(
+          "F8E5M2FNUZ and F8E4M3FNUZ types are not supported on non-gfx942 "
+          "(MI-300) chipsets; try F8E5M2 or F8E4M3FN instead.");
+    }
+  }
   return success();
 }
 
@@ -108,11 +128,6 @@ struct ConvertToROCDLPass final
   void runOnOperation() override {
     ModuleOp m = getOperation();
 
-    m.walk([&](Operation *op) {
-      if (failed(validateDataTypes(op)))
-        return signalPassFailure();
-    });
-
     if (clROCMIndexingBits != 32 && clROCMIndexingBits != 64) {
       m.emitOpError() << "unsupported: ROCm index bit widths must either be "
                          "64 or 32, got "
@@ -152,6 +167,16 @@ struct ConvertToROCDLPass final
       m.emitOpError() << "Invalid chipset name: " << chipset;
       return signalPassFailure();
     }
+    WalkResult allTypesValid = m.walk([&](Operation *op) {
+      if (failed(validateDataTypes(op, *maybeChipset))) {
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    if (allTypesValid.wasInterrupted()) {
+      return signalPassFailure();
+    }
+
     arith::populateArithToAMDGPUConversionPatterns(
         patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
         /*allowPackedF16Rtz=*/false, /*chipset=*/*maybeChipset);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 77b36bcc116b..6cc6f300627d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -26,6 +26,8 @@ iree_lit_test_suite(
             "config_vector_distribute_reduction_gfx942.mlir",
             "config_user_vector_distribute.mlir",
             "lowering_scalar_dispatch.mlir",
+            "pipeline_elementwise_f8fnuz.mlir",
+            "pipeline_elementwise_f8ocp.mlir",
             "pipeline_igemm_tile_and_fuse.mlir",
             "pipeline_tile_and_fuse.mlir",
             "pipeline_vector_distribute_gfx942.mlir",
@@ -39,5 +41,6 @@ iree_lit_test_suite(
     tools = [
         "//tools:iree-opt",
         "@llvm-project//llvm:FileCheck",
+        "@llvm-project//llvm:not",
     ],
 )
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 56891c8f5d93..6732627e1cd8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -22,6 +22,8 @@ iree_lit_test_suite(
     "config_vector_distribute_gfx942.mlir"
     "config_vector_distribute_reduction_gfx942.mlir"
     "lowering_scalar_dispatch.mlir"
+    "pipeline_elementwise_f8fnuz.mlir"
+    "pipeline_elementwise_f8ocp.mlir"
     "pipeline_igemm_tile_and_fuse.mlir"
    "pipeline_tile_and_fuse.mlir"
     "pipeline_vector_distribute_gfx1100.mlir"
@@ -31,6 +33,7 @@ iree_lit_test_suite(
   TOOLS
     FileCheck
     iree-opt
+    not
 )
 
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8fnuz.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8fnuz.mlir
new file mode 100644
index 000000000000..363fdc9ca00d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8fnuz.mlir
@@ -0,0 +1,53 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" %s | FileCheck %s --check-prefix=CDNA3
+// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS
+// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx1201 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS
+
+#map = affine_map<(d0) -> (d0)>
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @ext_fp8_dispatch {
+  hal.executable.variant @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export @ext_fp8_dispatch layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @ext_fp8_dispatch() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FNUZ>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2FNUZ>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FNUZ>> -> tensor<4096xf8E4M3FNUZ>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2FNUZ>> -> tensor<4096xf8E5M2FNUZ>
+        %5 = tensor.empty() : tensor<4096xf32>
+        %6 = linalg.generic {indexing_maps = [#map, #map, #map],
+                             iterator_types = ["parallel"]}
+               ins(%3, %4 : tensor<4096xf8E4M3FNUZ>, tensor<4096xf8E5M2FNUZ>)
+               outs(%5 : tensor<4096xf32>) {
+        ^bb0(%in0: f8E4M3FNUZ, %in1: f8E5M2FNUZ, %out: f32):
+          %7 = arith.extf %in0 : f8E4M3FNUZ to f32
+          %8 = arith.extf %in1 : f8E5M2FNUZ to f32
+          %9 = arith.addf %7, %8 : f32
+          linalg.yield %9 : f32
+        } -> tensor<4096xf32>
+        flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
+        return
+      }
+    }
+  }
+}
+
+// ERRORS: F8E5M2FNUZ and F8E4M3FNUZ types are not supported on non-gfx942 (MI-300) chipsets; try F8E5M2 or F8E4M3FN instead.
+
+// CDNA3-LABEL: hal.executable public @ext_fp8_dispatch
+// CDNA3: hal.executable.variant public @rocm
+// CDNA3-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32
+// CDNA3-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32
+// CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32>
+// CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1>
+
+// -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8ocp.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8ocp.mlir
new file mode 100644
index 000000000000..89a4b9bd7c92
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8ocp.mlir
@@ -0,0 +1,53 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1201 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" %s | FileCheck %s --check-prefix=RDNA4
+// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS
+// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS
+
+#map = affine_map<(d0) -> (d0)>
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @ext_fp8_dispatch {
+  hal.executable.variant @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export @ext_fp8_dispatch layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @ext_fp8_dispatch() {
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FN>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FN>> -> tensor<4096xf8E4M3FN>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2>> -> tensor<4096xf8E5M2>
+        %5 = tensor.empty() : tensor<4096xf32>
+        %6 = linalg.generic {indexing_maps = [#map, #map, #map],
+                             iterator_types = ["parallel"]}
+               ins(%3, %4 : tensor<4096xf8E4M3FN>, tensor<4096xf8E5M2>)
+               outs(%5 : tensor<4096xf32>) {
+        ^bb0(%in0: f8E4M3FN, %in1: f8E5M2, %out: f32):
+          %7 = arith.extf %in0 : f8E4M3FN to f32
+          %8 = arith.extf %in1 : f8E5M2 to f32
+          %9 = arith.addf %7, %8 : f32
+          linalg.yield %9 : f32
+        } -> tensor<4096xf32>
+        flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096xf32>>
+        return
+      }
+    }
+  }
+}
+
+// ERRORS: F8E5M2 and F8E4M3FN types are not supported on gfx942 (MI-300) or older chipsets; try F8E5M2FNUZ or F8E4M3FNUZ instead.
+
+// RDNA4-LABEL: hal.executable public @ext_fp8_dispatch
+// RDNA4: hal.executable.variant public @rocm
+// RDNA4-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32
+// RDNA4-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32
+// RDNA4: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32>
+// RDNA4: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1>
+
+// -----
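-- 

For readers who want to poke at the gating rule in isolation: the sketch
below is a minimal standalone illustration of the chipset logic in
validateDataTypes, not part of the patch above. It assumes only MLIR's
amdgpu::Chipset::parse and amdgpu::hasOcpFp8 helpers (from
mlir/Dialect/AMDGPU/Utils/Chipset.h), both of which the patch itself already
uses; the fp8FamilyAllowed helper and the small driver are hypothetical names
introduced here for illustration.

  // Illustrative sketch: mirrors the two checks in validateDataTypes.
  // Only Chipset::parse and hasOcpFp8 come from MLIR; everything else
  // is hypothetical.
  #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
  #include "mlir/Support/LogicalResult.h"
  #include "llvm/Support/raw_ostream.h"

  using mlir::amdgpu::Chipset;

  // True if the given FP8 family may appear in a dispatch compiled for
  // `chipset`: the OCP formats require hasOcpFp8, the FNUZ formats
  // require exactly gfx942.
  static bool fp8FamilyAllowed(const Chipset &chipset, bool ocpFormats) {
    constexpr Chipset kGfx942 = Chipset(9, 4, 2);
    if (ocpFormats)
      return mlir::amdgpu::hasOcpFp8(chipset); // f8E5M2 / f8E4M3FN
    return chipset == kGfx942;                 // f8E5M2FNUZ / f8E4M3FNUZ
  }

  int main() {
    for (const char *name : {"gfx908", "gfx942", "gfx1201"}) {
      mlir::FailureOr<Chipset> chipset = Chipset::parse(name);
      if (mlir::failed(chipset))
        continue;
      llvm::outs() << name << ": ocp="
                   << fp8FamilyAllowed(*chipset, /*ocpFormats=*/true)
                   << " fnuz="
                   << fp8FamilyAllowed(*chipset, /*ocpFormats=*/false)
                   << "\n";
    }
    return 0;
  }

Expected behavior, matching the lit tests above: gfx908 allows neither FP8
family, gfx942 allows only the FNUZ formats, and gfx1201 allows only the OCP
formats.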