-
Notifications
You must be signed in to change notification settings - Fork 671
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AMDGPU] Make fp8 support checks chipset-specific, check shaped types (…
…#20152) With RDNA4, we now have cards that allow the OCP FP8 formats (f8E5M2 and f8E4M3FN). This means that the previous check that caused any module that used those formats to be rejected needs to be relaxed. However, the backend uses the same intrinsics for the gfx942-only FNUZ types and the OCP types. Therefore, the data type validity check needs to be aware of the chipset being targetted in order to validate that the dispatch can be compiled correctly. This patch implements this improved set of checks. In addition, it adds tests for these checks (sadly, we can't used expected-error here because the pipeline scrips debug info) and improves these checks to also look inside vector<> and memref<>. (This also means that there is now an IREE-side mechanism to reject compiling FP8 code pre-gfx942.)
- Loading branch information
Showing
5 changed files
with
151 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8fnuz.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" %s | FileCheck %s --check-prefix=CDNA3 | ||
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS | ||
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx1201 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS | ||
|
||
#map = affine_map<(d0) -> (d0)> | ||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer>, | ||
#hal.pipeline.binding<storage_buffer>, | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
hal.executable @ext_fp8_dispatch { | ||
hal.executable.variant @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { | ||
hal.executable.export @ext_fp8_dispatch layout(#pipeline_layout) { | ||
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index): | ||
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 | ||
hal.return %x, %y, %z : index, index, index | ||
} | ||
builtin.module { | ||
func.func @ext_fp8_dispatch() { | ||
%c0 = arith.constant 0 : index | ||
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FNUZ>> | ||
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2FNUZ>> | ||
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf32>> | ||
%3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FNUZ>> -> tensor<4096xf8E4M3FNUZ> | ||
%4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2FNUZ>> -> tensor<4096xf8E5M2FNUZ> | ||
%5 = tensor.empty() : tensor<4096xf32> | ||
%6 = linalg.generic {indexing_maps = [#map, #map, #map], | ||
iterator_types = ["parallel"]} | ||
ins(%3, %4 : tensor<4096xf8E4M3FNUZ>, tensor<4096xf8E5M2FNUZ>) | ||
outs(%5 : tensor<4096xf32>) { | ||
^bb0(%in0: f8E4M3FNUZ, %in1: f8E5M2FNUZ, %out: f32): | ||
%7 = arith.extf %in0 : f8E4M3FNUZ to f32 | ||
%8 = arith.extf %in1 : f8E5M2FNUZ to f32 | ||
%9 = arith.addf %7, %8 : f32 | ||
linalg.yield %9 : f32 | ||
} -> tensor<4096xf32> | ||
flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096xf32>> | ||
return | ||
} | ||
} | ||
} | ||
} | ||
|
||
// ERRORS: F8E5M2FNUZ and F8E4M3FNUZ types are not supported on non-gfx942 (MI-300) chipsets; try F8E5M2 or F8E4M3FN instead. | ||
|
||
// CDNA3-LABEL: hal.executable public @ext_fp8_dispatch | ||
// CDNA3: hal.executable.variant public @rocm | ||
// CDNA3-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32 | ||
// CDNA3-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32 | ||
// CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32> | ||
// CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1> | ||
|
||
// ----- |
53 changes: 53 additions & 0 deletions
53
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_elementwise_f8ocp.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1201 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" %s | FileCheck %s --check-prefix=RDNA4 | ||
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS | ||
// RUN: not iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-rocdl-pipeline)))" -o /dev/null 2>&1 %s | FileCheck %s --check-prefix=ERRORS | ||
|
||
#map = affine_map<(d0) -> (d0)> | ||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer>, | ||
#hal.pipeline.binding<storage_buffer>, | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
hal.executable @ext_fp8_dispatch { | ||
hal.executable.variant @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { | ||
hal.executable.export @ext_fp8_dispatch layout(#pipeline_layout) { | ||
^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index, %arg3 : index): | ||
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2, %arg3 | ||
hal.return %x, %y, %z : index, index, index | ||
} | ||
builtin.module { | ||
func.func @ext_fp8_dispatch() { | ||
%c0 = arith.constant 0 : index | ||
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FN>> | ||
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2>> | ||
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096xf32>> | ||
%3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E4M3FN>> -> tensor<4096xf8E4M3FN> | ||
%4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf8E5M2>> -> tensor<4096xf8E5M2> | ||
%5 = tensor.empty() : tensor<4096xf32> | ||
%6 = linalg.generic {indexing_maps = [#map, #map, #map], | ||
iterator_types = ["parallel"]} | ||
ins(%3, %4 : tensor<4096xf8E4M3FN>, tensor<4096xf8E5M2>) | ||
outs(%5 : tensor<4096xf32>) { | ||
^bb0(%in0: f8E4M3FN, %in1: f8E5M2, %out: f32): | ||
%7 = arith.extf %in0 : f8E4M3FN to f32 | ||
%8 = arith.extf %in1 : f8E5M2 to f32 | ||
%9 = arith.addf %7, %8 : f32 | ||
linalg.yield %9 : f32 | ||
} -> tensor<4096xf32> | ||
flow.dispatch.tensor.store %6, %2, offsets = [0], sizes = [4096], strides = [1] : tensor<4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4096xf32>> | ||
return | ||
} | ||
} | ||
} | ||
} | ||
|
||
// ERRORS: F8E5M2 and F8E4M3FN types are not supported on gfx942 (MI-300) or older chipsets; try F8E5M2FNUZ or F8E4M3FNUZ instead. | ||
|
||
// RDNA4-LABEL: hal.executable public @ext_fp8_dispatch | ||
// RDNA4: hal.executable.variant public @rocm | ||
// RDNA4-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32 | ||
// RDNA4-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32 | ||
// RDNA4: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32> | ||
// RDNA4: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1> | ||
|
||
// ----- |