[GPU] Take element type bitwidth into account with vector size.
Co-authored-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Nirvedh <[email protected]>
MaheshRavishankar authored and nirvedhmeshram committed Feb 13, 2025
1 parent 8fab35c commit 97c9fcf
Showing 2 changed files with 77 additions and 8 deletions.
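
In short, the per-thread vector width is no longer a fixed 4 elements; it is derived from a representative element bit width so that each thread reads roughly 128 bits of contiguous data. The helper below is a hypothetical sketch (not part of the diff) that only mirrors the `std::max(4u, 128 / representativeBitWidth)` expression introduced by this change.

#include <algorithm>

// Hypothetical helper mirroring the expression added by this commit: a
// 128-bit load holds 128 / bitWidth elements, but never use fewer than 4.
static unsigned vectorElementsPerThread(unsigned representativeBitWidth) {
  return std::max(4u, 128u / representativeBitWidth);
}

// vectorElementsPerThread(16) == 8   (f16 inputs -> 8 elements per thread)
// vectorElementsPerThread(32) == 4   (f32 inputs -> 4 elements per thread)
// vectorElementsPerThread(8)  == 16  (i8 inputs  -> 16 elements per thread)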
@@ -503,6 +503,45 @@ static bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
getElementCount(contractionDims->n) != 1;
}

// To find the number of vector elements per work-item, find a
// bit width that is representative of the computation.
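// For example, an elementwise add on two f16 tensors gives 16; if one of the
// inputs were f32, the maximum bit width, 32, would be chosen instead.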
static unsigned getRepresentativeBitWidth(linalg::LinalgOp linalgOp) {
// Check all the inputs with permutation indexing maps. Use
// the maximum of those to get the bit width.
std::optional<unsigned> maxBitWidth;
auto updateElementTypeBitWidth = [&](Value v) {
auto elementType = getElementTypeOrSelf(v);
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
if (maxBitWidth) {
maxBitWidth = std::max(maxBitWidth.value(), bitWidth);
return;
}
maxBitWidth = bitWidth;
};
for (OpOperand *input : linalgOp.getDpsInputOperands()) {
AffineMap inputOperandMap = linalgOp.getMatchingIndexingMap(input);
if (!inputOperandMap.isPermutation()) {
continue;
}
updateElementTypeBitWidth(input->get());
}
if (maxBitWidth) {
return maxBitWidth.value();
}

// If none of the inputs has a permutation indexing map, use the results.
// Don't bother about the indexing map.
for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
updateElementTypeBitWidth(output.get());
}
if (maxBitWidth) {
return maxBitWidth.value();
}

// Fall back to word size (32 bits).
return 32;
}

LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op) {
@@ -572,6 +611,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
bool vectorizable = projPerm && powTwo;

const unsigned minBitwidth = getMinElementBitwidth(linalgOp);
const unsigned representativeBitWidth = getRepresentativeBitWidth(linalgOp);
// Make sure we use a tile size that results in some integral number of bytes.
const unsigned scaleToByte =
std::max(8 / minBitwidth, static_cast<unsigned>(1));
@@ -580,7 +620,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
auto distributeToThreads = [&](int64_t numThreads,
std::optional<int64_t> lossFactor =
std::nullopt) {
-LDBG("Loss factor: " << lossFactor << "\n");
+LDBG("Loss factor: " << lossFactor);
// Initialize the configuration.
flatWorkgroupSize = 1;
// Initialize thread tiling along all partitioned loops with size 1, and
@@ -607,13 +647,23 @@

// Ensure vectorization works with the `workgroupTileMultiple`.
int64_t workgroupTileMultiple = workgroupTileSizeMultiples[shapeDim];
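// Aim for roughly 128 bits of contiguous data per thread (e.g. 8 x f16 or
// 4 x f32), but never fewer than 4 elements.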
unsigned numVectorElements = std::max(4u, 128 / representativeBitWidth);
int64_t vectorizableCandidate = numVectorElements * numThreads;
// For smaller shapes, reduce `numVectorElements`, since otherwise we may not
// find work for all threads, and vectorization is not enabled when
// distributing with a loss factor.
while (vectorizable && (vectorizableCandidate > loopBound) &&
numVectorElements > 4) {
numVectorElements /= 2;
vectorizableCandidate = numVectorElements * numThreads;
}
vectorizable =
-vectorizable && 4 * numThreads % workgroupTileMultiple == 0;
-// For the inner most workgroup dim, try to see if we can have 4
-// elements per thread. This enables vectorization.
+vectorizable && vectorizableCandidate % workgroupTileMultiple == 0;

if (vectorizable && wgDim == 0 && !lossFactor) {
-candidates.push_back(4 * numThreads);
+candidates.push_back(vectorizableCandidate);
}

// Try all power of two multiples of `workgroupTileMultiple` up to the
// subgroup size.
uint64_t maxCandidate =
@@ -645,17 +695,17 @@
llvm::divideCeil(loopBound, scaledTileSize) <= 2) {
continue;
}

// Try to let each thread handle 4 elements if this is the workgroup x
// dimension.
// TODO: Try to take into account element type bit width to get
// 4xdword reads instead of 4x{elements}.
-if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
+if (vectorizable && wgDim == 0 && !lossFactor &&
+    candidate % numVectorElements == 0) {
// Use size-1 vectors to increase parallelism if larger ones cause
// idle threads in the subgroup.
bool hasIdleThreads =
partitionableLoops.size() == 1 && candidate <= subgroupSize;
-int vectorSize = hasIdleThreads ? 1 : 4;
+int vectorSize = hasIdleThreads ? 1 : numVectorElements;
LLVM_DEBUG(llvm::dbgs() << "Use vector size: " << vectorSize << "\n");
threadTileSizes[shapeDim] = vectorSize * scaleToByte;
candidateWorkgroupSize = candidate / vectorSize;
@@ -219,6 +219,25 @@ module {

// -----

module {
func.func @elementwise_dynamic_dim_large(%11: tensor<?x512xf16>, %12: tensor<?x512xf16>) -> tensor<?x512xf16> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%8 = tensor.dim %11, %c0 : tensor<?x512xf16>
%13 = tensor.empty(%8) : tensor<?x512xf16>
%15 = linalg.add ins(%11, %12 : tensor<?x512xf16>, tensor<?x512xf16>) outs(%13 : tensor<?x512xf16>) -> tensor<?x512xf16>
return %15 : tensor<?x512xf16>
}
}

// CHECK-LABEL: func.func @elementwise_dynamic_dim_large
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.add {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1, 8]
// CHECK-SAME: workgroup = [1, 512]
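// With f16 inputs the representative bit width is 16, so each thread covers
// 128 / 16 = 8 elements (thread = [1, 8]); 64 threads x 8 elements matches the
// workgroup tile of 512 along the inner dimension.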

// -----

module @elementwise_unaligned {
func.func @elementwise_unaligned(%11: tensor<180x180xf16>, %12: tensor<180x180xf16>) -> tensor<180x180xf16> {
%cst = arith.constant 0.000000e+00 : f32
