Skip to content


[LLVMGPU] Add fixes and tests for horizontally fused gemms through GPU pipeline. (#19930)
Browse files Browse the repository at this point in the history

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Feb 13, 2025
1 parent 04dc4a4 commit eb58f82
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,18 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
IRRewriter builder(opOperand.getOwner());
// Handle case where constantOp may have multiple consumers with different
// layouts by creating a copy of constOp for other users.
if (!opOperand.get().hasOneUse() && !vectorLayout &&
if (!opOperand.get().hasOneUse() &&
llvm::isa_and_nonnull<arith::ConstantOp, vector::StepOp>(
opOperand.get().getDefiningOp())) {
Operation *copiedConstOp = builder.clone(*opOperand.get().getDefiningOp());
Value copiedConst = copiedConstOp->getResult(0);
builder.replaceAllUsesExcept(opOperand.get(), copiedConst,
DistributionLayout *newConstLayout =
return ChangeResult::NoChange;

ResolutionResult result = doResolution(rhs);
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,

// Set anchors at tensor level for vector distribution later and hoist out
// loop invariant anchors.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ iree_lit_test_suite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ iree_lit_test_suite(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
// RUN: iree-opt --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target))" %s --split-input-file | FileCheck %s

// Horizontally fused contraction: one linalg.generic computes three products
// that share the LHS operand %arg0 and use %arg1, %arg2 and %arg3 as the three
// RHS operands. The f16 inputs are extended to f32 for the multiply-accumulate,
// and each f32 result is truncated back to f16 by its own elementwise
// linalg.generic.
func.func @fused_contraction_1(%arg0: tensor<2x4096x640xf16>,
    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
    %arg3 : tensor<10x64x640xf16>)
    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>) {
  // %11 is the f16 destination shared by the three truncation ops; %12 backs
  // the zero-filled f32 accumulator %13 shared by the three contraction
  // results.
  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0: f32
  %13 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  // d4 is the only reduction dimension. All three RHS operands use the same
  // indexing map (d1, d3, d4) and all three results use (d0, d1, d2, d3).
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%13, %13, %13
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) {
    ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
      // The extended LHS element %18 is reused by all three products.
      %18 = arith.extf %in : f16 to f32
      %19 = arith.extf %in_0 : f16 to f32
      %20 = arith.mulf %18, %19 : f32
      %21 = arith.addf %out, %20 : f32
      %22 = arith.extf %in_1 : f16 to f32
      %23 = arith.mulf %18, %22 : f32
      %24 = arith.addf %out_3, %23 : f32
      %25 = arith.extf %in_2 : f16 to f32
      %26 = arith.mulf %18, %25 : f32
      %27 = arith.addf %out_4, %26 : f32
      linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>)
  // Elementwise f32 -> f16 truncation of the first contraction result.
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  // Elementwise f32 -> f16 truncation of the second contraction result.
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  // Elementwise f32 -> f16 truncation of the third contraction result.
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#2 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  return %15, %16, %17
      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>
}
// CHECK-LABEL: func @fused_contraction_1
// CHECK-COUNT-24: amdgpu.mfma

// -----

// Horizontally fused contraction in pure f32: one linalg.generic computes
// three 2-D matmul-like products that share the LHS operand %arg0 and use
// %arg1, %arg2 and %arg3 as the three RHS operands. There is no element type
// extension or truncation; the three results are returned directly.
func.func @fused_contraction_2(%arg0: tensor<4096x640xf32>,
    %arg1 : tensor<640x640xf32>, %arg2 : tensor<640x640xf32>,
    %arg3 : tensor<640x640xf32>)
    -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
  // NOTE(review): %11 is not referenced anywhere else in this function.
  %11 = tensor.empty() : tensor<4096x640xf32>
  %12 = tensor.empty() : tensor<4096x640xf32>
  %cst = arith.constant 0.0: f32
  // Zero-filled accumulator shared by all three results.
  %13 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<4096x640xf32>) -> tensor<4096x640xf32>
  // d2 is the only reduction dimension. The RHS operands are all indexed as
  // (d2, d1) and the results as (d0, d1).
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<4096x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>,
            tensor<640x640xf32>)
      outs(%13, %13, %13
          : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
    ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32, %out_4: f32):
      // The same LHS element %in feeds all three multiply-accumulates.
      %20 = arith.mulf %in, %in_0 : f32
      %21 = arith.addf %out, %20 : f32
      %23 = arith.mulf %in, %in_1 : f32
      %24 = arith.addf %out_3, %23 : f32
      %26 = arith.mulf %in, %in_2 : f32
      %27 = arith.addf %out_4, %26 : f32
      linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>)
  return %14#0, %14#1, %14#2
      : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>
}
// CHECK-LABEL: func @fused_contraction_2
// CHECK-COUNT-24: amdgpu.mfma

// -----

// Horizontally fused integer contraction: one linalg.generic computes two
// batched products that share the LHS operand %arg0 and use %arg1 and %arg2 as
// the two RHS operands. The i8 inputs are sign-extended to i32 for the
// multiply-accumulate; each i32 result is then converted to f32 and truncated
// to f16 by its own elementwise linalg.generic.
func.func @fused_contraction_3(%arg0 : tensor<2x4096x640xi8>,
    %arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>)
    -> (tensor<2x4096x640xf16>, tensor<2x4096x640xf16>) {
  %c0_i32 = arith.constant 0 : i32
  // %18 is the f16 destination shared by both conversion ops; %19 backs the
  // zero-filled i32 accumulator %20 shared by both contraction results.
  %18 = tensor.empty() : tensor<2x4096x640xf16>
  %19 = tensor.empty() : tensor<2x4096x640xi32>
  %20 = linalg.fill ins(%c0_i32 : i32)
      outs(%19 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
  // d3 is the only reduction dimension. Both RHS operands are indexed as
  // (d0, d2, d3) and both results as (d0, d1, d2).
  %21:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
      outs(%20, %20 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
    ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
      // The sign-extended LHS element %24 is reused by both products.
      %24 = arith.extsi %in : i8 to i32
      %25 = arith.extsi %in_0 : i8 to i32
      %26 = arith.muli %24, %25 : i32
      %27 = arith.addi %out, %26 : i32
      %28 = arith.extsi %in_1 : i8 to i32
      %29 = arith.muli %24, %28 : i32
      %30 = arith.addi %out_2, %29 : i32
      linalg.yield %27, %30 : i32, i32
  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
  // Elementwise i32 -> f32 -> f16 conversion of the first contraction result.
  %22 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%21#0 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
    ^bb0(%in: i32, %out: f16):
      %27 = arith.sitofp %in : i32 to f32
      %29 = arith.truncf %27 : f32 to f16
      linalg.yield %29 : f16
  } -> tensor<2x4096x640xf16>
  // Elementwise i32 -> f32 -> f16 conversion of the second contraction result.
  %23 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%21#1 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
    ^bb0(%in: i32, %out: f16):
      %27 = arith.sitofp %in : i32 to f32
      %29 = arith.truncf %27 : f32 to f16
      linalg.yield %29 : f16
  } -> tensor<2x4096x640xf16>
  return %22, %23 : tensor<2x4096x640xf16>, tensor<2x4096x640xf16>
}
// CHECK-LABEL: func @fused_contraction_3
// CHECK-COUNT-24: amdgpu.mfma

// -----

// Horizontally fused contraction with one transposed result: one
// linalg.generic computes three products that share the LHS operand %arg0 and
// use %arg1, %arg2 and %arg3 as the three RHS operands. The first two results
// are indexed as (d0, d1, d2, d3) while the third is indexed as
// (d0, d1, d3, d2), so it has shape 2x10x64x4096 instead of 2x10x4096x64.
// The f16 inputs are extended to f32 for the multiply-accumulate, and each f32
// result is truncated back to f16 by its own elementwise linalg.generic.
func.func @fused_contraction_4(%arg0: tensor<2x4096x640xf16>,
    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
    %arg3 : tensor<10x64x640xf16>)
    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>) {
  // %9 / %10 are the f16 destination and f32 accumulator storage for the
  // transposed result; %11 / %12 serve the two non-transposed results.
  %9 = tensor.empty() : tensor<2x10x64x4096xf16>
  %10 = tensor.empty() : tensor<2x10x64x4096xf32>
  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0: f32
  %fill0 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  %fill1 = linalg.fill ins(%cst : f32)
      outs(%10 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32>
  // d4 is the only reduction dimension. The last indexing map swaps d2 and d3
  // relative to the other two results.
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%fill0, %fill0, %fill1
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) {
    ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
      // The extended LHS element %18 is reused by all three products.
      %18 = arith.extf %in : f16 to f32
      %19 = arith.extf %in_0 : f16 to f32
      %20 = arith.mulf %18, %19 : f32
      %21 = arith.addf %out, %20 : f32
      %22 = arith.extf %in_1 : f16 to f32
      %23 = arith.mulf %18, %22 : f32
      %24 = arith.addf %out_3, %23 : f32
      %25 = arith.extf %in_2 : f16 to f32
      %26 = arith.mulf %18, %25 : f32
      %27 = arith.addf %out_4, %26 : f32
      linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>)
  // Elementwise f32 -> f16 truncation of the first contraction result.
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  // Elementwise f32 -> f16 truncation of the second contraction result.
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  // Elementwise f32 -> f16 truncation of the third (transposed) result, written
  // into the 2x10x64x4096 destination %9.
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#2 : tensor<2x10x64x4096xf32>) outs(%9 : tensor<2x10x64x4096xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x64x4096xf16>
  return %15, %16, %17
      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>
}
// CHECK-LABEL: func @fused_contraction_4
// CHECK-COUNT-24: amdgpu.mfma

0 comments on commit eb58f82

Please sign in to comment.