[GPU] Take element type bitwidth into account with vector size. (#19987)
This is for the TileAndFuse workgroup and thread distribution logic. Before
this PR we only tried to vectorize to a size of 4 elements; now we vectorize
to a total bit width of 128. This can, however, result in no vectorization for
small dimensions, so the PR also introduces a fallback to a vector size of 4
for smaller shapes.
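
A minimal standalone C++ sketch of the heuristic described above (not the PR's code verbatim; the function name, thread count, and loop bounds here are hypothetical): target roughly 128 bits per thread, then halve back toward the old default of 4 elements when the dimension is too small to give every thread work.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Pick the number of vector elements per thread: aim for a 128-bit access
// (f32 -> 4 elements, f16 -> 8, i8 -> 16), then shrink for small shapes so
// that all threads still get work, but never below the old default of 4.
int64_t pickVectorElements(unsigned elementBitWidth, int64_t numThreads,
                           int64_t loopBound) {
  int64_t numVectorElements = std::max<int64_t>(4, 128 / elementBitWidth);
  while (numVectorElements * numThreads > loopBound && numVectorElements > 4)
    numVectorElements /= 2;
  return numVectorElements;
}

int main() {
  // Hypothetical launch of 64 threads over f16 (16-bit) data.
  std::cout << pickVectorElements(16, 64, /*loopBound=*/5120) << "\n"; // 8
  std::cout << pickVectorElements(16, 64, /*loopBound=*/320) << "\n";  // 4
}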

Co-authored-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Nirvedh <[email protected]>

nirvedhmeshram and MaheshRavishankar authored Feb 14, 2025
1 parent e2f3565 commit cdba184
Showing 3 changed files with 87 additions and 12 deletions.
@@ -503,6 +503,48 @@ static bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
getElementCount(contractionDims->n) != 1;
}

// To find the number of vector elements per work-item, find a
// bit width that is representative of the computation.
static unsigned getRepresentativeBitWidth(linalg::LinalgOp linalgOp) {
// Check all the inputs with permutation indexing maps. Use
// the maximum of those to get the bit width.
std::optional<unsigned> maxBitWidth;
auto updateElementTypeBitWidth = [&](Value v) {
auto elementType = getElementTypeOrSelf(v);
if (!elementType.isIntOrFloat()) {
return;
}
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
if (maxBitWidth) {
maxBitWidth = std::max(maxBitWidth.value(), bitWidth);
return;
}
maxBitWidth = bitWidth;
};
for (OpOperand *input : linalgOp.getDpsInputOperands()) {
AffineMap inputOperandMap = linalgOp.getMatchingIndexingMap(input);
if (!inputOperandMap.isPermutation()) {
continue;
}
updateElementTypeBitWidth(input->get());
}
if (maxBitWidth) {
return maxBitWidth.value();
}

// If none of the input operands have permutation indexing maps, use the
// results. Don't bother about the indexing map.
for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
updateElementTypeBitWidth(output.get());
}
if (maxBitWidth) {
return maxBitWidth.value();
}

// Fall back to a 32-bit word.
return 32;
}

LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op) {
@@ -572,6 +614,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
bool vectorizable = projPerm && powTwo;

const unsigned minBitwidth = getMinElementBitwidth(linalgOp);
const unsigned representativeBitWidth = getRepresentativeBitWidth(linalgOp);
// Make sure we use a tile size that results in some integral number of bytes.
const unsigned scaleToByte =
std::max(8 / minBitwidth, static_cast<unsigned>(1));
@@ -580,7 +623,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
auto distributeToThreads = [&](int64_t numThreads,
std::optional<int64_t> lossFactor =
std::nullopt) {
LDBG("Loss factor: " << lossFactor << "\n");
LDBG("Loss factor: " << lossFactor);
// Initialize the configuration.
flatWorkgroupSize = 1;
// Initialize thread tiling along all partitioned loops with size 1, and
@@ -607,13 +650,23 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,

// Ensure vectorization works with the `workgroupTileMultiple`.
int64_t workgroupTileMultiple = workgroupTileSizeMultiples[shapeDim];
unsigned numVectorElements = std::max(4u, 128 / representativeBitWidth);
int64_t vectorizableCandidate = numVectorElements * numThreads;
// For smaller shapes, we reduce `numVectorElements` as we may not find
// work for all threads otherwise, and we don't have vectorization enabled
// with a loss factor.
while (vectorizable && (vectorizableCandidate > loopBound) &&
numVectorElements > 4) {
numVectorElements /= 2;
vectorizableCandidate = numVectorElements * numThreads;
}
vectorizable =
vectorizable && 4 * numThreads % workgroupTileMultiple == 0;
// For the inner most workgroup dim, try to see if we can have 4
// elements per thread. This enables vectorization.
vectorizable && vectorizableCandidate % workgroupTileMultiple == 0;

if (vectorizable && wgDim == 0 && !lossFactor) {
candidates.push_back(4 * numThreads);
candidates.push_back(vectorizableCandidate);
}

// Try all power of two multiples of `workgroupTileMultiple` up to the
// subgroup size.
uint64_t maxCandidate =
@@ -645,17 +698,17 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
llvm::divideCeil(loopBound, scaledTileSize) <= 2) {
continue;
}

// Try to let each thread handle 4 elements if this is the workgroup x
// dimension.
// TODO: Try to take into account element type bit width to get
// 4xdword reads instead of 4x{elements}.
if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
if (vectorizable && wgDim == 0 && !lossFactor &&
candidate % numVectorElements == 0) {
// Use size-1 vectors to increase parallelism if larger ones causes
// idle threads in the subgroup.
bool hasIdleThreads =
partitionableLoops.size() == 1 && candidate <= subgroupSize;
int vectorSize = hasIdleThreads ? 1 : 4;
int vectorSize = hasIdleThreads ? 1 : numVectorElements;
LLVM_DEBUG(llvm::dbgs() << "Use vector size: " << vectorSize << "\n");
threadTileSizes[shapeDim] = vectorSize * scaleToByte;
candidateWorkgroupSize = candidate / vectorSize;
@@ -654,3 +654,25 @@ func.func @pack_dynamic_tile(%arg0: tensor<32x32xi8>, %d0: index, %d1: index, %t
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1, 4]
// CHECK-SAME: workgroup = [8, 32]

// -----

module {
func.func @erf(%13 : tensor<2x1024x5120xf16>, %12 : tensor<2x1024x5120xf16>, %9 : tensor<5120xf16>, %10 : tensor<f32>) -> tensor<2x1024x5120xi8> {
%cst = arith.constant 0.000000e+00 : f16
%11 = tensor.empty() : tensor<2x1024x5120xi8>
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13, %12, %9, %10 : tensor<2x1024x5120xf16>, tensor<2x1024x5120xf16>, tensor<5120xf16>, tensor<f32>) outs(%11 : tensor<2x1024x5120xi8>) {
^bb0(%in: f16, %in_4: f16, %in_5: f16, %in_6: f32, %out: i8):
%17 = math.erf %in : f16
%30 = arith.fptosi %17 : f16 to i8
linalg.yield %30 : i8
} -> tensor<2x1024x5120xi8>
return %14 : tensor<2x1024x5120xi8>
}
}

// CHECK-LABEL: func.func @erf
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1, 1, 8]
// CHECK-SAME: workgroup = [1, 1, 512]
@@ -137,10 +137,10 @@ hal.executable @ext_fp8_dispatch {

// CDNA3-LABEL: hal.executable public @ext_fp8_dispatch
// CDNA3: hal.executable.variant public @rocm
// CDNA3-COUNT-4: rocdl.cvt.f32.fp8 %{{.*}} : f32
// CDNA3-COUNT-4: rocdl.cvt.f32.bf8 %{{.*}} : f32
// CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<4xf32>
// CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<4xf32>, !llvm.ptr<1>
// CDNA3-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32
// CDNA3-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32
// CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32>
// CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1>

// -----
