diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index c447c9635167..1205942dd9a2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -503,6 +503,45 @@ static bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
          getElementCount(contractionDims->n) != 1;
 }
 
+// To find the number of vector elements per work-item, use a bit width
+// that is representative of the computation.
+static unsigned getRepresentativeBitWidth(linalg::LinalgOp linalgOp) {
+  // Check all the inputs with permutation indexing maps and use the
+  // maximum of their element type bit widths.
+  std::optional<unsigned> maxBitWidth;
+  auto updateElementTypeBitWidth = [&](Value v) {
+    auto elementType = getElementTypeOrSelf(v);
+    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+    if (maxBitWidth) {
+      maxBitWidth = std::max(maxBitWidth.value(), bitWidth);
+      return;
+    }
+    maxBitWidth = bitWidth;
+  };
+  for (OpOperand *input : linalgOp.getDpsInputOperands()) {
+    AffineMap inputOperandMap = linalgOp.getMatchingIndexingMap(input);
+    if (!inputOperandMap.isPermutation()) {
+      continue;
+    }
+    updateElementTypeBitWidth(input->get());
+  }
+  if (maxBitWidth) {
+    return maxBitWidth.value();
+  }
+
+  // If none of the inputs have permutation indexing maps, use the results.
+  // Don't bother about the indexing maps.
+  for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
+    updateElementTypeBitWidth(output.get());
+  }
+  if (maxBitWidth) {
+    return maxBitWidth.value();
+  }
+
+  // Fall back to a word (32 bits).
+  return 32;
+}
+
 LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
                                            mlir::FunctionOpInterface entryPoint,
                                            Operation *op) {
@@ -572,6 +611,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
   bool vectorizable = projPerm && powTwo;
 
   const unsigned minBitwidth = getMinElementBitwidth(linalgOp);
+  const unsigned representativeBitWidth = getRepresentativeBitWidth(linalgOp);
   // Make sure we use a tile size that results in some integral number of bytes.
   const unsigned scaleToByte =
       std::max(8 / minBitwidth, static_cast<unsigned>(1));
@@ -580,7 +620,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
   auto distributeToThreads = [&](int64_t numThreads,
                                  std::optional<int64_t> lossFactor =
                                      std::nullopt) {
-    LDBG("Loss factor: " << lossFactor << "\n");
+    LDBG("Loss factor: " << lossFactor);
     // Initialize the configuration.
     flatWorkgroupSize = 1;
     // Initialize thread tiling along all partitioned loops with size 1, and
@@ -607,13 +647,23 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
 
       // Ensure vectorization works with the `workgroupTileMultiple`.
      int64_t workgroupTileMultiple = workgroupTileSizeMultiples[shapeDim];
+      unsigned numVectorElements = std::max(4u, 128 / representativeBitWidth);
+      int64_t vectorizableCandidate = numVectorElements * numThreads;
+      // For smaller shapes, reduce `numVectorElements` since we may otherwise
+      // not find work for all threads, and vectorization is not enabled when a
+      // loss factor is used.
+      while (vectorizable && (vectorizableCandidate > loopBound) &&
+             numVectorElements > 4) {
+        numVectorElements /= 2;
+        vectorizableCandidate = numVectorElements * numThreads;
+      }
       vectorizable =
-          vectorizable && 4 * numThreads % workgroupTileMultiple == 0;
-      // For the inner most workgroup dim, try to see if we can have 4
-      // elements per thread. This enables vectorization.
+          vectorizable && vectorizableCandidate % workgroupTileMultiple == 0;
+
       if (vectorizable && wgDim == 0 && !lossFactor) {
-        candidates.push_back(4 * numThreads);
+        candidates.push_back(vectorizableCandidate);
       }
+
       // Try all power of two multiples of `workgroupTileMultiple` up to the
       // subgroup size.
       uint64_t maxCandidate =
@@ -645,17 +695,17 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
           llvm::divideCeil(loopBound, scaledTileSize) <= 2) {
         continue;
       }
-
       // Try to let each thread handle 4 elements if this is the workgroup x
       // dimension.
       // TODO: Try to take into account element type bit width to get
       // 4xdword reads instead of 4x{elements}.
-      if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
+      if (vectorizable && wgDim == 0 && !lossFactor &&
+          candidate % numVectorElements == 0) {
         // Use size-1 vectors to increase parallelism if larger ones causes
         // idle threads in the subgroup.
         bool hasIdleThreads = partitionableLoops.size() == 1 &&
                               candidate <= subgroupSize;
-        int vectorSize = hasIdleThreads ? 1 : 4;
+        int vectorSize = hasIdleThreads ? 1 : numVectorElements;
         LLVM_DEBUG(llvm::dbgs() << "Use vector size: " << vectorSize << "\n");
         threadTileSizes[shapeDim] = vectorSize * scaleToByte;
         candidateWorkgroupSize = candidate / vectorSize;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index c769f0424860..602e19502abc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -219,6 +219,25 @@ module {
 
 // -----
 
+module {
+  func.func @elementwise_dynamic_dim_large(%11: tensor, %12: tensor) -> tensor {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %8 = tensor.dim %11, %c0 : tensor
+    %13 = tensor.empty(%8) : tensor
+    %15 = linalg.add ins(%11, %12 : tensor, tensor) outs(%13 : tensor) -> tensor
+    return %15 : tensor
+  }
+}
+
+// CHECK-LABEL: func.func @elementwise_dynamic_dim_large
+// CHECK-SAME: #iree_codegen.translation_info
+// CHECK: linalg.add {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: thread = [1, 8]
+// CHECK-SAME: workgroup = [1, 512]
+
+// -----
+
 module @elementwise_unaligned {
   func.func @elementwise_unaligned(%11: tensor<180x180xf16>, %12: tensor<180x180xf16>) -> tensor<180x180xf16> {
     %cst = arith.constant 0.000000e+00 : f32
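
Note, not part of the patch: a minimal standalone sketch of the vector-width heuristic the hunks above introduce. The helper name, the `main` driver, and the example loop bound are illustrative assumptions rather than IREE API; only the arithmetic mirrors the patch (target roughly 128 bits of data per work-item, never fewer than 4 elements, and halve while the candidate workgroup tile would exceed the loop bound).

// Sketch only: assumed, non-IREE helper reproducing the heuristic above.
#include <algorithm>
#include <cstdint>
#include <iostream>

static unsigned pickNumVectorElements(unsigned representativeBitWidth,
                                      int64_t numThreads, int64_t loopBound) {
  // Aim for ~128 bits per work-item, but never fewer than 4 elements.
  unsigned numVectorElements = std::max(4u, 128 / representativeBitWidth);
  int64_t candidate = numVectorElements * numThreads;
  // Shrink for small shapes so every thread still gets work.
  while (candidate > loopBound && numVectorElements > 4) {
    numVectorElements /= 2;
    candidate = numVectorElements * numThreads;
  }
  return numVectorElements;
}

int main() {
  // f16 add with a large inner dimension and 64 threads: 128 / 16 = 8
  // elements per thread, i.e. thread tile 8 and workgroup tile 8 * 64 = 512,
  // matching the `thread = [1, 8]` / `workgroup = [1, 512]` CHECK lines in
  // the new test.
  std::cout << pickNumVectorElements(16, 64, 1 << 20) << "\n"; // prints 8
  // f32 keeps the previous behavior of 4 elements per thread.
  std::cout << pickNumVectorElements(32, 64, 1 << 20) << "\n"; // prints 4
}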