diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 94bb2246bf52..ac0ecf86158a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -490,6 +490,15 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> bounds = op.getStaticLoopRanges();
   FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
       mlir::linalg::inferContractionDims(op);
+  if (failed(contractionDims)) {
+    assert(IREE::LinalgExt::isaHorizontallyFusedContraction(op) &&
+           "expected horizontally fused contraction op");
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.push_back(op.getMatchingIndexingMap(op.getDpsInputOperand(0)));
+    indexingMaps.push_back(op.getMatchingIndexingMap(op.getDpsInputOperand(1)));
+    indexingMaps.push_back(op.getMatchingIndexingMap(op.getDpsInitOperand(0)));
+    contractionDims = mlir::linalg::inferContractionDims(indexingMaps);
+  }
   assert(succeeded(contractionDims) && "Could not infer contraction dims");
 
   if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
@@ -602,6 +611,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
              /*bestMNTileCountPerSubgroup=*/8,
              /*bestKTileCountPerSubgroup=*/4};
   }
+  // Scale the seed by the number of contractions in the horizontally fused case.
+  seeds.bestMNTileCountPerSubgroup /= op.getNumDpsInputs() - 1;
 
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
@@ -699,7 +710,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   SmallVector<NamedAttribute, 2> attrs = {
       NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)),
      NamedAttribute("reduction", b.getI64ArrayAttr(reductionTileSizes))};
-  IREE::GPU::setPromotedOperandList(context, attrs, {0, 1});
+  auto promotedOperands =
+      llvm::to_vector(llvm::seq<int64_t>(op.getNumDpsInputs()));
+  IREE::GPU::setPromotedOperandList(context, attrs, promotedOperands);
   IREE::GPU::setMmaKind(context, attrs, mmaKinds[schedule->index]);
   IREE::GPU::setSubgroupMCount(context, attrs, schedule->mSubgroupCounts[0]);
   IREE::GPU::setSubgroupNCount(context, attrs, schedule->nSubgroupCounts[0]);
@@ -1204,7 +1217,8 @@ setVectorDistributionConfig(IREE::GPU::TargetAttr target,
   LDBG("VectorDistribution: finding a suitable config...");
 
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
-    if (linalg::isaContractionOpInterface(linalgOp)) {
+    if (linalg::isaContractionOpInterface(linalgOp) ||
+        IREE::LinalgExt::isaHorizontallyFusedContraction(linalgOp)) {
       LDBG("VectorDistribution: trying to find a suitable contraction config");
       return setMatmulVectorDistributionConfig(target, entryPoint, linalgOp);
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 2183e35a75af..909c4ce15c18 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -29,6 +29,7 @@ iree_lit_test_suite(
             "cast_address_space_function.mlir",
             "cast_type_to_fit_mma.mlir",
             "config_custom_op.mlir",
+            "config_horizontally_fused_ops.mlir",
             "config_matvec.mlir",
             "config_root_op_attribute.mlir",
             "config_winograd.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 53b06322befb..284c966535b5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -18,6 +18,7 @@ iree_lit_test_suite(
     "cast_address_space_function.mlir"
     "cast_type_to_fit_mma.mlir"
     "config_custom_op.mlir"
+    "config_horizontally_fused_ops.mlir"
     "config_matvec.mlir"
     "config_root_op_attribute.mlir"
     "config_winograd.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir
new file mode 100644
index 000000000000..87a3d9c18b41
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_horizontally_fused_ops.mlir
@@ -0,0 +1,283 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' --mlir-print-local-scope %s | FileCheck %s
+
+func.func @fused_contraction_1(%arg0: tensor<2x4096x640xf16>,
+    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
+    %arg3 : tensor<10x64x640xf16>)
+    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>) {
+  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
+  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
+  %cst = arith.constant 0.0: f32
+  %13 = linalg.fill ins(%cst : f32)
+      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
+  %14:3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1, %arg2, %arg3
+          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
+            tensor<10x64x640xf16>)
+      outs(%13, %13, %13
+          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
+    %18 = arith.extf %in : f16 to f32
+    %19 = arith.extf %in_0 : f16 to f32
+    %20 = arith.mulf %18, %19 : f32
+    %21 = arith.addf %out, %20 : f32
+    %22 = arith.extf %in_1 : f16 to f32
+    %23 = arith.mulf %18, %22 : f32
+    %24 = arith.addf %out_3, %23 : f32
+    %25 = arith.extf %in_2 : f16 to f32
+    %26 = arith.mulf %18, %25 : f32
+    %27 = arith.addf %out_4, %26 : f32
+    linalg.yield %21, %24, %27 : f32, f32, f32
+  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>)
+  %15 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %18 = arith.truncf %in : f32 to f16
+    linalg.yield %18 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %16 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %18 = arith.truncf %in : f32 to f16
+    linalg.yield %18 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %17 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%14#2 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %18 = arith.truncf %in : f32 to f16
+    linalg.yield %18 : f16
+  } -> tensor<2x10x4096x64xf16>
+  return %15, %16, %17
+      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>
+}
+// CHECK-LABEL: func @fused_contraction_1
+// CHECK-SAME:   translation_info = #iree_codegen.translation_info
+// CHECK-SAME:     pipeline = LLVMGPUVectorDistribute
+// CHECK-SAME:     workgroup_size = [256, 1, 1]
+// CHECK-SAME:     subgroup_size = 64
+// CHECK:        %[[GENERIC:.+]]:3 = linalg.generic
+// CHECK-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:       mma_kind = #iree_gpu.mma_layout
+
+// -----
+
+func.func @fused_contraction_2(%arg0: tensor<4096x640xf32>,
+    %arg1 : tensor<640x640xf32>, %arg2 : tensor<640x640xf32>,
+    %arg3 : tensor<640x640xf32>)
+    -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
+  %11 = tensor.empty() : tensor<4096x640xf32>
+  %12 = tensor.empty() : tensor<4096x640xf32>
+  %cst = arith.constant 0.0: f32
+  %13 = linalg.fill ins(%cst : f32)
+      outs(%12 : tensor<4096x640xf32>) -> tensor<4096x640xf32>
+  %14:3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1, %arg2, %arg3
+          : tensor<4096x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>,
+            tensor<640x640xf32>)
+      outs(%13, %13, %13
+          : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32, %out_4: f32):
+    %20 = arith.mulf %in, %in_0 : f32
+    %21 = arith.addf %out, %20 : f32
+    %23 = arith.mulf %in, %in_1 : f32
+    %24 = arith.addf %out_3, %23 : f32
+    %26 = arith.mulf %in, %in_2 : f32
+    %27 = arith.addf %out_4, %26 : f32
+    linalg.yield %21, %24, %27 : f32, f32, f32
+  } -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>)
+  return %14#0, %14#1, %14#2
+      : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>
+}
+// CHECK-LABEL: func @fused_contraction_2
+// CHECK-SAME:   translation_info = #iree_codegen.translation_info
+// CHECK-SAME:     pipeline = LLVMGPUVectorDistribute
+// CHECK-SAME:     workgroup_size = [256, 1, 1]
+// CHECK-SAME:     subgroup_size = 64
+// CHECK:        %[[GENERIC:.+]]:3 = linalg.generic
+// CHECK-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:       mma_kind = #iree_gpu.mma_layout
+
+// -----
+
+func.func @fused_contraction_3(%arg0: tensor<2x4096x640xi8>,
+    %arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>)
+    -> (tensor<2x4096x640xf16>, tensor<2x4096x640xf16>) {
+  %c0_i32 = arith.constant 0 : i32
+  %18 = tensor.empty() : tensor<2x4096x640xf16>
+  %19 = tensor.empty() : tensor<2x4096x640xi32>
+  %20 = linalg.fill ins(%c0_i32 : i32)
+      outs(%19 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
+  %21:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1, %arg2 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
+      outs(%20, %20 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
+  ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
+    %24 = arith.extsi %in : i8 to i32
+    %25 = arith.extsi %in_0 : i8 to i32
+    %26 = arith.muli %24, %25 : i32
+    %27 = arith.addi %out, %26 : i32
+    %28 = arith.extsi %in_1 : i8 to i32
+    %29 = arith.muli %24, %28 : i32
+    %30 = arith.addi %out_2, %29 : i32
+    linalg.yield %27, %30 : i32, i32
+  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
+  %22 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%21#0 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
+  ^bb0(%in: i32, %out: f16):
+    %27 = arith.sitofp %in : i32 to f32
+    %29 = arith.truncf %27 : f32 to f16
+    linalg.yield %29 : f16
+  } -> tensor<2x4096x640xf16>
+  %23 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%21#1 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
+  ^bb0(%in: i32, %out: f16):
+    %27 = arith.sitofp %in : i32 to f32
+    %29 = arith.truncf %27 : f32 to f16
+    linalg.yield %29 : f16
+  } -> tensor<2x4096x640xf16>
+  return %22, %23 : tensor<2x4096x640xf16>, tensor<2x4096x640xf16>
+}
+// CHECK-LABEL: func @fused_contraction_3
+// CHECK-SAME:   translation_info = #iree_codegen.translation_info
+// CHECK-SAME:     pipeline = LLVMGPUVectorDistribute
+// CHECK-SAME:     workgroup_size = [256, 1, 1]
+// CHECK-SAME:     subgroup_size = 64
+// CHECK:        %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:       mma_kind = #iree_gpu.mma_layout
+
+// -----
+
+func.func @fused_contraction_4(%arg0: tensor<2x4096x640xf16>,
+    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
+    %arg3 : tensor<10x64x640xf16>)
+    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>) {
+  %9 = tensor.empty() : tensor<2x10x64x4096xf16>
+  %10 = tensor.empty() : tensor<2x10x64x4096xf32>
+  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
+  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
+  %cst = arith.constant 0.0: f32
+  %fill0 = linalg.fill ins(%cst : f32)
+      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
+  %fill1 = linalg.fill ins(%cst : f32)
+      outs(%10 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32>
+  %14:3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1, %arg2, %arg3
+          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
+            tensor<10x64x640xf16>)
+      outs(%fill0, %fill0, %fill1
+          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) {
+  ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
+    %18 = arith.extf %in : f16 to f32
+    %19 = arith.extf %in_0 : f16 to f32
+    %20 = arith.mulf %18, %19 : f32
+    %21 = arith.addf %out, %20 : f32
+    %22 = arith.extf %in_1 : f16 to f32
+    %23 = arith.mulf %18, %22 : f32
+    %24 = arith.addf %out_3, %23 : f32
+    %25 = arith.extf %in_2 : f16 to f32
+    %26 = arith.mulf %18, %25 : f32
+    %27 = arith.addf %out_4, %26 : f32
+    linalg.yield %21, %24, %27 : f32, f32, f32
+  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>)
+  %15 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %18 = arith.truncf %in : f32 to f16
+    linalg.yield %18 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %16 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %18 = arith.truncf %in : f32 to f16
+    linalg.yield %18 : f16
+  } -> tensor<2x10x4096x64xf16>
+  %17 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%14#2 : tensor<2x10x64x4096xf32>) outs(%9 : tensor<2x10x64x4096xf16>) {
+  ^bb0(%in: f32, %out: f16):
+    %18 = arith.truncf %in : f32 to f16
+    linalg.yield %18 : f16
+  } -> tensor<2x10x64x4096xf16>
+  return %15, %16, %17
+      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>
+}
+// CHECK-LABEL: func @fused_contraction_4
+// CHECK-SAME:   translation_info = #iree_codegen.translation_info
+// CHECK-SAME:     pipeline = LLVMGPUVectorDistribute
+// CHECK-SAME:     workgroup_size = [256, 1, 1]
+// CHECK-SAME:     subgroup_size = 64
+// CHECK:        %[[GENERIC:.+]]:3 = linalg.generic
+// CHECK-SAME:     lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:       mma_kind = #iree_gpu.mma_layout
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+static Value getSourceSkipUnary(Value value) {
+  Operation *op = value.getDefiningOp();
+  while (op && op->getNumResults() == 1 && op->getNumOperands() == 1) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface || !iface.hasNoEffect())
+      break;
+    value = op->getOperand(0);
+    op = value.getDefiningOp();
+  }
+  return value;
+}
+
+struct ContractionOpSequenceArgs {
+  std::pair<BlockArgument, BlockArgument> operands;
+  BlockArgument accumulator;
+};
+static std::optional<ContractionOpSequenceArgs>
+isContractionOpSequence(Value yielded,
+                        function_ref<bool(Operation *, Operation *)> isaPair) {
+  Operation *reductionOp = yielded.getDefiningOp();
+  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+    return std::nullopt;
+  }
+
+  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
+  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
+
+  BlockArgument updated = dyn_cast<BlockArgument>(reductionRHS);
+  Value contributed = reductionLHS;
+  if (!updated) {
+    updated = dyn_cast<BlockArgument>(reductionLHS);
+    if (!updated) {
+      return std::nullopt;
+    }
+    contributed = reductionRHS;
+  }
+  contributed = getSourceSkipUnary(contributed);
+
+  Operation *elementwiseOp = contributed.getDefiningOp();
+  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
+      elementwiseOp->getNumOperands() != 2) {
+    return std::nullopt;
+  }
+
+  if (!isaPair(elementwiseOp, reductionOp)) {
+    return std::nullopt;
+  }
+
+  auto elementwiseLHS = dyn_cast_or_null<BlockArgument>(
+      getSourceSkipUnary(elementwiseOp->getOperand(0)));
+  auto elementwiseRHS = dyn_cast_or_null<BlockArgument>(
+      getSourceSkipUnary(elementwiseOp->getOperand(1)));
+  if (!elementwiseLHS || !elementwiseRHS) {
+    return std::nullopt;
+  }
+
+  return ContractionOpSequenceArgs{{elementwiseLHS, elementwiseRHS}, updated};
+}
+
+/// Returns true if the two operations are of the kinds specified by a pair of
+/// consecutive template arguments.
+template <typename AddOpTy, typename MulOpTy, typename... Args>
+static bool isPairTemplateImpl(Operation *add, Operation *mul) {
+  static_assert(sizeof...(Args) % 2 == 0,
+                "expected an even number of template arguments");
+  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
+    return true;
+
+  if constexpr (sizeof...(Args) > 0)
+    return isPairTemplateImpl<Args...>(add, mul);
+  else
+    return false;
+}
+
+/// Matches a contraction body ending at the yielded value, with the kinds of
+/// combining operations given pairwise by the template arguments.
+template <typename... Args>
+static std::optional<ContractionOpSequenceArgs>
+isContractionOpSequence(Value yielded) {
+  return isContractionOpSequence(yielded, &isPairTemplateImpl<Args...>);
+}
+
+/// Recognize an operation that is a horizontally fused contraction.
+/// TODO: The logic below is quite convoluted. It might be better
+/// to have a dedicated operation for this.
+bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
+  if (linalgOp->getNumResults() == 1) {
+    return false;
+  }
+  // Check that the number of `ins` is one more than the number of results.
+  if (linalgOp.getNumDpsInputs() != linalgOp->getNumResults() + 1) {
+    return false;
+  }
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  if (!llvm::all_of(indexingMaps, [](AffineMap m) {
+        return m.isProjectedPermutation() && !m.isPermutation();
+      })) {
+    return false;
+  }
+
+  llvm::SetVector<BlockArgument> rhsArgs;
+  llvm::SetVector<BlockArgument> outArgs;
+  for (auto [index, yieldedVal] :
+       llvm::enumerate(linalgOp.getBlock()->getTerminator()->getOperands())) {
+    std::optional<ContractionOpSequenceArgs> args =
+        isContractionOpSequence<arith::MulFOp, arith::AddFOp, arith::MulIOp,
+                                arith::AddIOp>(yieldedVal);
+    if (!args) {
+      return false;
+    }
+    BlockArgument lhs = args->operands.first;
+    BlockArgument rhs = args->operands.second;
+
+    // One of the block arguments must be argument 0, corresponding to the LHS.
+    if (lhs.getArgNumber() != 0) {
+      if (rhs.getArgNumber() != 0) {
+        return false;
+      }
+      std::swap(lhs, rhs);
+    }
+    assert(rhs.getArgNumber() != 0 && "cannot have rhs be arg number 0");
+    if (rhs.getArgNumber() != index + 1) {
+      return false;
+    }
+    BlockArgument accumulator = args->accumulator;
+    if (accumulator.getArgNumber() != index + linalgOp.getNumDpsInputs()) {
+      return false;
+    }
+  }
+
+  // Check that they have valid m, n and k dims.
+  ArrayRef<AffineMap> indexingMapsRef(indexingMaps);
+  AffineMap lhsIndexingMap = indexingMaps.front();
+
+  auto getResultDims = [](AffineMap m) {
+    auto r = llvm::map_range(m.getResults(), [](AffineExpr e) {
+      return cast<AffineDimExpr>(e).getPosition();
+    });
+    return llvm::SmallDenseSet<unsigned>(r.begin(), r.end());
+  };
+  llvm::SmallDenseSet<unsigned> lhsDims = getResultDims(lhsIndexingMap);
+
+  // Check that all the horizontally fused gemms have common N dims. The M and
+  // K dims are already known to be consistent since they come from the LHS.
+  std::optional<llvm::SmallDenseSet<unsigned>> refNDimsSet;
+  for (auto [rhsIndexingMap, outputIndexingMap] :
+       llvm::zip_equal(indexingMapsRef.slice(1, linalgOp.getNumDpsInputs() - 1),
+                       indexingMapsRef.take_back(linalgOp.getNumDpsInits()))) {
+    llvm::SmallDenseSet<unsigned> rhsDims = getResultDims(rhsIndexingMap);
+    llvm::SmallDenseSet<unsigned> outsDims = getResultDims(outputIndexingMap);
+    llvm::SmallDenseSet<unsigned> mDims = lhsDims;
+    llvm::set_intersect(mDims, outsDims);
+    if (mDims.empty()) {
+      return false;
+    }
+    llvm::SmallDenseSet<unsigned> nDims = rhsDims;
+    llvm::set_intersect(nDims, outsDims);
+    if (nDims.empty()) {
+      return false;
+    }
+    llvm::SmallDenseSet<unsigned> kDims = lhsDims;
+    llvm::set_intersect(kDims, rhsDims);
+    if (kDims.empty()) {
+      return false;
+    }
+
+    if (refNDimsSet) {
+      if (!llvm::all_of(nDims, [&](unsigned nDim) {
+            return refNDimsSet->contains(nDim);
+          })) {
+        return false;
+      }
+    } else {
+      refNDimsSet = std::move(nDims);
+    }
+  }
+  return true;
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index b9afb32fbd5e..c02699f80c25 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -211,5 +211,10 @@ bool isBroadcastingOp(linalg::LinalgOp op);
 /// 2. `linalg.yield` consumes the result of a `tensor.extract_slice`
 bool isGatherlikeOp(Operation *op);
 
+/// Check if a given operation is a horizontally fused contraction operation.
+/// The expectation is that the LHS is common to all the fused contractions,
+/// and each remaining input operand is a distinct RHS.
+bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
 #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
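
Illustration (not part of the patch): distilled from the test cases above, the structure that isaHorizontallyFusedContraction accepts is a single linalg.generic whose first input is an LHS shared by every result, whose remaining inputs provide one RHS per result, and whose yielded values are independent multiply-accumulate chains off the same LHS element. A minimal two-result sketch with hypothetical shapes and names:

func.func @fused_contraction_sketch(%lhs: tensor<128x256xf32>,
    %rhs0: tensor<256x64xf32>, %rhs1: tensor<256x64xf32>)
    -> (tensor<128x64xf32>, tensor<128x64xf32>) {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<128x64xf32>
  %acc = linalg.fill ins(%cst : f32)
      outs(%empty : tensor<128x64xf32>) -> tensor<128x64xf32>
  // One shared LHS map, one RHS map per result, one output map per result.
  %r:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs0, %rhs1
          : tensor<128x256xf32>, tensor<256x64xf32>, tensor<256x64xf32>)
      outs(%acc, %acc : tensor<128x64xf32>, tensor<128x64xf32>) {
  ^bb0(%a: f32, %b0: f32, %b1: f32, %out0: f32, %out1: f32):
    // Each yielded value is an independent mul/add chain off the same LHS element.
    %m0 = arith.mulf %a, %b0 : f32
    %s0 = arith.addf %out0, %m0 : f32
    %m1 = arith.mulf %a, %b1 : f32
    %s1 = arith.addf %out1, %m1 : f32
    linalg.yield %s0, %s1 : f32, f32
  } -> (tensor<128x64xf32>, tensor<128x64xf32>)
  return %r#0, %r#1 : tensor<128x64xf32>, tensor<128x64xf32>
}

With the KernelConfig.cpp changes above, such an op would route through setMatmulVectorDistributionConfig: all three inputs would be promoted (promotedOperands = seq(getNumDpsInputs()) = {0, 1, 2}), and bestMNTileCountPerSubgroup would be divided by getNumDpsInputs() - 1 = 2, the number of fused contractions.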