From 93eb7c803851c74eb92f58b67b8973e722e85b94 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Mon, 10 Feb 2025 20:32:20 -0800
Subject: [PATCH] [LLVMGPU] Add initial kernel config for horizontally fused
 gemms. (#19923)

This is in preparation for the modified way of generating horizontally
fused GEMMs. This PR adds kernel configuration for these GEMM ops so
that they can go down the vector distribute pipeline.

---------

Signed-off-by: MaheshRavishankar
---
 .../compiler/Codegen/LLVMGPU/KernelConfig.cpp |  18 +-
 .../compiler/Codegen/LLVMGPU/test/BUILD.bazel |   1 +
 .../Codegen/LLVMGPU/test/CMakeLists.txt       |   1 +
 .../test/config_horizontally_fused_ops.mlir   | 283 ++++++++++++++++++
 .../Dialect/LinalgExt/Utils/BUILD.bazel       |   1 +
 .../Dialect/LinalgExt/Utils/CMakeLists.txt    |   1 +
 .../Dialect/LinalgExt/Utils/Utils.cpp         | 191 ++++++++++++
 .../compiler/Dialect/LinalgExt/Utils/Utils.h  |   5 +
 8 files changed, 499 insertions(+), 2 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 94bb2246bf52..ac0ecf86158a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -490,6 +490,15 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> bounds = op.getStaticLoopRanges();
   FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
       mlir::linalg::inferContractionDims(op);
+  if (failed(contractionDims)) {
+    assert(IREE::LinalgExt::isaHorizontallyFusedContraction(op) &&
+           "expected horizontally fused contraction op");
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.push_back(op.getMatchingIndexingMap(op.getDpsInputOperand(0)));
+    indexingMaps.push_back(op.getMatchingIndexingMap(op.getDpsInputOperand(1)));
+    indexingMaps.push_back(op.getMatchingIndexingMap(op.getDpsInitOperand(0)));
+    contractionDims = mlir::linalg::inferContractionDims(indexingMaps);
+  }
   assert(succeeded(contractionDims) && "Could not infer contraction dims");

   if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
@@ -602,6 +611,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
         /*bestMNTileCountPerSubgroup=*/8,
         /*bestKTileCountPerSubgroup=*/4};
   }
+  // Scale the seed by the number of contractions in the horizontally fused
+  // case.
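+  // For a plain contraction (two DPS inputs) the divisor below is 1 and the
+  // seeds are unchanged. For a horizontally fused op with one LHS shared by
+  // three RHS operands (four DPS inputs, three results) the divisor is 3,
+  // so a seed of bestMNTileCountPerSubgroup = 8 becomes 8 / 3 = 2 under
+  // integer division, keeping the combined tile footprint of all fused
+  // GEMMs comparable to that of a single GEMM.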
+ seeds.bestMNTileCountPerSubgroup /= op.getNumDpsInputs() - 1; int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes(); @@ -699,7 +710,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, SmallVector attrs = { NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)), NamedAttribute("reduction", b.getI64ArrayAttr(reductionTileSizes))}; - IREE::GPU::setPromotedOperandList(context, attrs, {0, 1}); + auto promotedOperands = + llvm::to_vector(llvm::seq(op.getNumDpsInputs())); + IREE::GPU::setPromotedOperandList(context, attrs, promotedOperands); IREE::GPU::setMmaKind(context, attrs, mmaKinds[schedule->index]); IREE::GPU::setSubgroupMCount(context, attrs, schedule->mSubgroupCounts[0]); IREE::GPU::setSubgroupNCount(context, attrs, schedule->nSubgroupCounts[0]); @@ -1204,7 +1217,8 @@ setVectorDistributionConfig(IREE::GPU::TargetAttr target, LDBG("VectorDistribution: finding a suitable config..."); if (auto linalgOp = dyn_cast(computeOp)) { - if (linalg::isaContractionOpInterface(linalgOp)) { + if (linalg::isaContractionOpInterface(linalgOp) || + IREE::LinalgExt::isaHorizontallyFusedContraction(linalgOp)) { LDBG("VectorDistribution: trying to find a suitable contraction config"); return setMatmulVectorDistributionConfig(target, entryPoint, linalgOp); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 2183e35a75af..909c4ce15c18 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -29,6 +29,7 @@ iree_lit_test_suite( "cast_address_space_function.mlir", "cast_type_to_fit_mma.mlir", "config_custom_op.mlir", + "config_horizontally_fused_ops.mlir", "config_matvec.mlir", "config_root_op_attribute.mlir", "config_winograd.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index 53b06322befb..284c966535b5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -18,6 +18,7 @@ iree_lit_test_suite( "cast_address_space_function.mlir" "cast_type_to_fit_mma.mlir" "config_custom_op.mlir" + "config_horizontally_fused_ops.mlir" "config_matvec.mlir" "config_root_op_attribute.mlir" "config_winograd.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir new file mode 100644 index 000000000000..87a3d9c18b41 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir @@ -0,0 +1,283 @@ +// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' --mlir-print-local-scope %s | FileCheck %s + +func.func @fused_contraction_1(%arg0: tensor<2x4096x640xf16>, + %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>, + %arg3 : tensor<10x64x640xf16>) + -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>) { + %11 = tensor.empty() : tensor<2x10x4096x64xf16> + %12 = tensor.empty() : tensor<2x10x4096x64xf32> + %cst = arith.constant 0.0: f32 + %13 = linalg.fill ins(%cst : f32) + outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32> + %14:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + 
affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2, %arg3 + : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>, + tensor<10x64x640xf16>) + outs(%13, %13, %13 + : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) { + ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32): + %18 = arith.extf %in : f16 to f32 + %19 = arith.extf %in_0 : f16 to f32 + %20 = arith.mulf %18, %19 : f32 + %21 = arith.addf %out, %20 : f32 + %22 = arith.extf %in_1 : f16 to f32 + %23 = arith.mulf %18, %22 : f32 + %24 = arith.addf %out_3, %23 : f32 + %25 = arith.extf %in_2 : f16 to f32 + %26 = arith.mulf %18, %25 : f32 + %27 = arith.addf %out_4, %26 : f32 + linalg.yield %21, %24, %27 : f32, f32, f32 + } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %16 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %17 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#2 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + return %15, %16, %17 + : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16> +} +// CHECK-LABEL: func @fused_contraction_1 +// CHECK-SAME: translation_info = #iree_codegen.translation_info +// CHECK-SAME: pipeline = LLVMGPUVectorDistribute +// CHECK-SAME: workgroup_size = [256, 1, 1] +// CHECK-SAME: subgroup_size = 64 +// CHECK: %[[GENERIC:.+]]:3 = linalg.generic +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout, + %arg1 : tensor<640x640xf32>, %arg2 : tensor<640x640xf32>, + %arg3 : tensor<640x640xf32>) + -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) { + %11 = tensor.empty() : tensor<4096x640xf32> + %12 = tensor.empty() : tensor<4096x640xf32> + %cst = arith.constant 0.0: f32 + %13 = linalg.fill ins(%cst : f32) + outs(%12 : tensor<4096x640xf32>) -> tensor<4096x640xf32> + %14:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, 
d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2, %arg3 + : tensor<4096x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>, + tensor<640x640xf32>) + outs(%13, %13, %13 + : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32, %out_4: f32): + %20 = arith.mulf %in, %in_0 : f32 + %21 = arith.addf %out, %20 : f32 + %23 = arith.mulf %in, %in_1 : f32 + %24 = arith.addf %out_3, %23 : f32 + %26 = arith.mulf %in, %in_2 : f32 + %27 = arith.addf %out_4, %26 : f32 + linalg.yield %21, %24, %27 : f32, f32, f32 + } -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) + return %14#0, %14#1, %14#2 + : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32> +} +// CHECK-LABEL: func @fused_contraction_2 +// CHECK-SAME: translation_info = #iree_codegen.translation_info +// CHECK-SAME: pipeline = LLVMGPUVectorDistribute +// CHECK-SAME: workgroup_size = [256, 1, 1] +// CHECK-SAME: subgroup_size = 64 +// CHECK: %[[GENERIC:.+]]:3 = linalg.generic +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout, + %arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>) + -> (tensor<2x4096x640xf16>, tensor<2x4096x640xf16>) { + %c0_i32 = arith.constant 0 : i32 + %18 = tensor.empty() : tensor<2x4096x640xf16> + %19 = tensor.empty() : tensor<2x4096x640xi32> + %20 = linalg.fill ins(%c0_i32 : i32) + outs(%19 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32> + %21:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>) + outs(%20, %20 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) { + ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32): + %24 = arith.extsi %in : i8 to i32 + %25 = arith.extsi %in_0 : i8 to i32 + %26 = arith.muli %24, %25 : i32 + %27 = arith.addi %out, %26 : i32 + %28 = arith.extsi %in_1 : i8 to i32 + %29 = arith.muli %24, %28 : i32 + %30 = arith.addi %out_2, %29 : i32 + linalg.yield %27, %30 : i32, i32 + } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) + %22 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%21#0 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) { + ^bb0(%in: i32, %out: f16): + %27 = arith.sitofp %in : i32 to f32 + %29 = arith.truncf %27 : f32 to f16 + linalg.yield %29 : f16 + } -> tensor<2x4096x640xf16> + %23 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%21#1 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) { + ^bb0(%in: i32, %out: f16): + %27 = arith.sitofp %in : i32 to f32 + %29 = arith.truncf %27 : f32 to f16 + linalg.yield %29 : f16 + } -> tensor<2x4096x640xf16> + return %22, %23 : tensor<2x4096x640xf16>, tensor<2x4096x640xf16> +} +// CHECK-LABEL: func 
@fused_contraction_3 +// CHECK-SAME: translation_info = #iree_codegen.translation_info +// CHECK-SAME: pipeline = LLVMGPUVectorDistribute +// CHECK-SAME: workgroup_size = [256, 1, 1] +// CHECK-SAME: subgroup_size = 64 +// CHECK: %[[GENERIC:.+]]:2 = linalg.generic +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout, + %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>, + %arg3 : tensor<10x64x640xf16>) + -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>) { + %9 = tensor.empty() : tensor<2x10x64x4096xf16> + %10 = tensor.empty() : tensor<2x10x64x4096xf32> + %11 = tensor.empty() : tensor<2x10x4096x64xf16> + %12 = tensor.empty() : tensor<2x10x4096x64xf32> + %cst = arith.constant 0.0: f32 + %fill0 = linalg.fill ins(%cst : f32) + outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32> + %fill1 = linalg.fill ins(%cst : f32) + outs(%10 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32> + %14:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1, %arg2, %arg3 + : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>, + tensor<10x64x640xf16>) + outs(%fill0, %fill0, %fill1 + : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) { + ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32): + %18 = arith.extf %in : f16 to f32 + %19 = arith.extf %in_0 : f16 to f32 + %20 = arith.mulf %18, %19 : f32 + %21 = arith.addf %out, %20 : f32 + %22 = arith.extf %in_1 : f16 to f32 + %23 = arith.mulf %18, %22 : f32 + %24 = arith.addf %out_3, %23 : f32 + %25 = arith.extf %in_2 : f16 to f32 + %26 = arith.mulf %18, %25 : f32 + %27 = arith.addf %out_4, %26 : f32 + linalg.yield %21, %24, %27 : f32, f32, f32 + } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) + %15 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %16 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x10x4096x64xf16> + %17 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%14#2 : tensor<2x10x64x4096xf32>) outs(%9 : tensor<2x10x64x4096xf16>) { + ^bb0(%in: f32, %out: f16): + %18 = arith.truncf %in : f32 to f16 + linalg.yield %18 : 
f16 + } -> tensor<2x10x64x4096xf16> + return %15, %16, %17 + : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16> +} +// CHECK-LABEL: func @fused_contraction_4 +// CHECK-SAME: translation_info = #iree_codegen.translation_info +// CHECK-SAME: pipeline = LLVMGPUVectorDistribute +// CHECK-SAME: workgroup_size = [256, 1, 1] +// CHECK-SAME: subgroup_size = 64 +// CHECK: %[[GENERIC:.+]]:3 = linalg.generic +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layoutgetNumOperands() == 1) { + auto iface = dyn_cast(op); + if (!iface || !iface.hasNoEffect()) + break; + value = op->getOperand(0); + op = value.getDefiningOp(); + } + return value; +} + +struct ContractionOpSequenceArgs { + std::pair operands; + BlockArgument accumulator; +}; +static std::optional +isContractionOpSequence(Value yielded, + function_ref isaPair) { + Operation *reductionOp = yielded.getDefiningOp(); + if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { + return std::nullopt; + } + + Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0)); + Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1)); + + BlockArgument updated = dyn_cast(reductionRHS); + Value contributed = reductionLHS; + if (!updated) { + updated = dyn_cast(reductionLHS); + if (!updated) { + return std::nullopt; + } + contributed = reductionRHS; + } + contributed = getSourceSkipUnary(contributed); + + Operation *elementwiseOp = contributed.getDefiningOp(); + if (!elementwiseOp || elementwiseOp->getNumResults() != 1 || + elementwiseOp->getNumOperands() != 2) { + return std::nullopt; + } + + if (!isaPair(elementwiseOp, reductionOp)) { + return std::nullopt; + } + + auto elementwiseLHS = dyn_cast_or_null( + getSourceSkipUnary(elementwiseOp->getOperand(0))); + auto elementwiseRHS = dyn_cast_or_null( + getSourceSkipUnary(elementwiseOp->getOperand(1))); + if (!elementwiseLHS || !elementwiseRHS) { + return std::nullopt; + } + + return ContractionOpSequenceArgs{{elementwiseLHS, elementwiseRHS}, updated}; +} + +/// Returns true if the two operations are of the kinds specified by a pair of +/// consecutive template arguments. +template +static bool isPairTemplateImpl(Operation *add, Operation *mul) { + static_assert(sizeof...(Args) % 2 == 0, + "expected an even number of template arguments"); + if (isa(add) && isa(mul)) + return true; + + if constexpr (sizeof...(Args) > 0) + return isPairTemplateImpl(add, mul); + else + return false; +} + +/// Returns true if the block is a body of a contraction with the kinds of +/// operations given pairwise by template arguments. +template +static std::optional +isContractionOpSequence(Value yielded) { + return isContractionOpSequence(yielded, &isPairTemplateImpl); +} + +/// Recognize an operation that is horizontally fused contraction. +/// TODO: The logic below is quite convoluted. Might be better +/// off having a dedicated operation for this. +bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) { + if (linalgOp->getNumResults() == 1) { + return false; + } + // Check that the number of `ins` is one more than the number of results. 
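+  // For example, three GEMMs that share one LHS are fused into an op with
+  // ins = (LHS, RHS0, RHS1, RHS2) and outs = (ACC0, ACC1, ACC2), so
+  // getNumDpsInputs() == 4 == getNumResults() + 1.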
+ if (linalgOp.getNumDpsInputs() != linalgOp->getNumResults() + 1) { + return false; + } + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + if (!llvm::all_of(indexingMaps, [](AffineMap m) { + return m.isProjectedPermutation() && !m.isPermutation(); + })) { + return false; + } + + llvm::SetVector rhsArgs; + llvm::SetVector outArgs; + for (auto [index, yieldedVal] : + llvm::enumerate(linalgOp.getBlock()->getTerminator()->getOperands())) { + std::optional args = + isContractionOpSequence(yieldedVal); + if (!args) { + return false; + } + BlockArgument lhs = args->operands.first; + BlockArgument rhs = args->operands.second; + + // One of the block arguments must be argument 0, corresponding to the LHS. + if (lhs.getArgNumber() != 0) { + if (rhs.getArgNumber() != 0) { + return false; + } + std::swap(lhs, rhs); + } + assert(rhs.getArgNumber() != 0 && "cannot have rhs be arg number 0"); + if (rhs.getArgNumber() != index + 1) { + return false; + } + BlockArgument accumulator = args->accumulator; + if (accumulator.getArgNumber() != index + linalgOp.getNumDpsInputs()) { + return false; + } + } + + // Check that they have valid m, n and k dims. + ArrayRef indexingMapsRef(indexingMaps); + AffineMap lhsIndexingMap = indexingMaps.front(); + + auto getResultDims = [](AffineMap m) { + auto r = llvm::map_range(m.getResults(), [](AffineExpr e) { + return cast(e).getPosition(); + }); + return llvm::SmallDenseSet(r.begin(), r.end()); + }; + llvm::SmallDenseSet lhsDims = getResultDims(lhsIndexingMap); + + // Check that all the horizontally fused gemms have common N-dims. M and K + // dims are already known consistent since they are what the LHS has. + std::optional> refNDimsSet; + for (auto [rhsIndexingMap, outputIndexingMap] : + llvm::zip_equal(indexingMapsRef.slice(1, linalgOp.getNumDpsInputs() - 1), + indexingMapsRef.take_back(linalgOp.getNumDpsInits()))) { + llvm::SmallDenseSet rhsDims = getResultDims(rhsIndexingMap); + llvm::SmallDenseSet outsDims = getResultDims(outputIndexingMap); + llvm::SmallDenseSet mDims = lhsDims; + llvm::set_intersect(mDims, outsDims); + if (mDims.empty()) { + return false; + } + llvm::SmallDenseSet nDims = rhsDims; + llvm::set_intersect(nDims, outsDims); + if (nDims.empty()) { + return false; + } + llvm::SmallDenseSet kDims = lhsDims; + llvm::set_intersect(kDims, rhsDims); + if (kDims.empty()) { + return false; + } + + if (refNDimsSet) { + if (!llvm::all_of(nDims, [&](unsigned nDim) { + return refNDimsSet->contains(nDim); + })) { + return false; + } + } else { + refNDimsSet = std::move(nDims); + } + } + return true; +} + } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h index b9afb32fbd5e..c02699f80c25 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h @@ -211,5 +211,10 @@ bool isBroadcastingOp(linalg::LinalgOp op); /// 2. `linalg.yield` consumes the result of a `tensor.extract_slice` bool isGatherlikeOp(Operation *op); +/// Check if a given operation is a horizontally fused contraction operation. +/// The expectation is that the LHS is common, and all the operands are +/// different RHS. +bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp); + } // namespace mlir::iree_compiler::IREE::LinalgExt #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
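To make the m/n/k validation in isaHorizontallyFusedContraction concrete, consider
the indexing maps of @fused_contraction_1 in the new test file: the LHS uses
(d0, d2, d4), every RHS uses (d1, d3, d4), and every result uses (d0, d1, d2, d3).
The intersections the utility computes are lhs ∩ out = {d0, d2} (the M-like dims,
here including the batch dim d0), rhs ∩ out = {d1, d3} (the N dims), and
lhs ∩ rhs = {d4} (the K dims). All three sets are non-empty and every RHS/output
pair produces the same N set, so the op is accepted; in @fused_contraction_4 the
last result map is a permutation, (d0, d1, d3, d2), but its dim set is unchanged,
so the common-N check still passes.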
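The KernelConfig.cpp hunk above relies on the caller having already matched the op
with isaHorizontallyFusedContraction before asserting. A minimal sketch of that
control flow as one standalone helper is below; the name inferFusedContractionDims
and its packaging as a free function are hypothetical (not part of this patch), it
assumes the same headers as KernelConfig.cpp, and it uses the ArrayRef<AffineMap>
overload of inferContractionDims that the patch itself calls.

// Hypothetical helper mirroring the logic added to
// setMatmulVectorDistributionConfig in this patch.
static FailureOr<mlir::linalg::ContractionDimensions>
inferFusedContractionDims(linalg::LinalgOp op) {
  // The single-result path: defer to upstream inference.
  FailureOr<mlir::linalg::ContractionDimensions> dims =
      mlir::linalg::inferContractionDims(op);
  if (succeeded(dims))
    return dims;
  if (!IREE::LinalgExt::isaHorizontallyFusedContraction(op))
    return failure();
  // All fused GEMMs share the LHS (hence M and K) and a common N set, so the
  // LHS, the first RHS, and the first init describe one representative
  // contraction whose dims can be inferred from the maps alone.
  SmallVector<AffineMap> indexingMaps = {
      op.getMatchingIndexingMap(op.getDpsInputOperand(0)),
      op.getMatchingIndexingMap(op.getDpsInputOperand(1)),
      op.getMatchingIndexingMap(op.getDpsInitOperand(0))};
  return mlir::linalg::inferContractionDims(indexingMaps);
}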