[GPU] Take element type bitwidth into account with vector size. #19987

Merged (4 commits, Feb 14, 2025)
@@ -503,6 +503,48 @@ static bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
getElementCount(contractionDims->n) != 1;
}

// To find the number of vector elements per work-item, find a
// bit width that is representative of the computation.
static unsigned getRepresentativeBitWidth(linalg::LinalgOp linalgOp) {
// Check all the inputs with permutation indexing maps. Use
// the maximum of those to get the bit width.
std::optional<unsigned> maxBitWidth;
auto updateElementTypeBitWidth = [&](Value v) {
auto elementType = getElementTypeOrSelf(v);
if (!elementType.isIntOrFloat()) {
return;
}
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
if (maxBitWidth) {
maxBitWidth = std::max(maxBitWidth.value(), bitWidth);
return;
}
maxBitWidth = bitWidth;
};
for (OpOperand *input : linalgOp.getDpsInputOperands()) {
AffineMap inputOperandMap = linalgOp.getMatchingIndexingMap(input);
if (!inputOperandMap.isPermutation()) {
continue;
}
updateElementTypeBitWidth(input->get());
}
if (maxBitWidth) {
return maxBitWidth.value();
}

// If none of the inputs have a permutation indexing map, use the result.
// Don't bother about its indexing map.
for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
updateElementTypeBitWidth(output.get());
}
if (maxBitWidth) {
return maxBitWidth.value();
}

// Fall back to a 32-bit word.
return 32;
}

LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op) {
@@ -572,6 +614,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
bool vectorizable = projPerm && powTwo;

const unsigned minBitwidth = getMinElementBitwidth(linalgOp);
const unsigned representativeBitWidth = getRepresentativeBitWidth(linalgOp);
// Make sure we use a tile size that results in some integral number of bytes.
const unsigned scaleToByte =
std::max(8 / minBitwidth, static_cast<unsigned>(1));
@@ -580,7 +623,7 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
auto distributeToThreads = [&](int64_t numThreads,
std::optional<int64_t> lossFactor =
std::nullopt) {
LDBG("Loss factor: " << lossFactor << "\n");
LDBG("Loss factor: " << lossFactor);
// Initialize the configuration.
flatWorkgroupSize = 1;
// Initialize thread tiling along all partitioned loops with size 1, and
@@ -607,13 +650,23 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,

// Ensure vectorization works with the `workgroupTileMultiple`.
int64_t workgroupTileMultiple = workgroupTileSizeMultiples[shapeDim];
unsigned numVectorElements = std::max(4u, 128 / representativeBitWidth);
int64_t vectorizableCandidate = numVectorElements * numThreads;
// For smaller shapes, we reduce `numVectorElements` as we may not find
// work for all threads otherwise and we don't have vectorization enabled
// with loss.
while (vectorizable && (vectorizableCandidate > loopBound) &&
numVectorElements > 4) {
numVectorElements /= 2;
vectorizableCandidate = numVectorElements * numThreads;
}
vectorizable =
- vectorizable && 4 * numThreads % workgroupTileMultiple == 0;
- // For the inner most workgroup dim, try to see if we can have 4
- // elements per thread. This enables vectorization.
+ vectorizable && vectorizableCandidate % workgroupTileMultiple == 0;

if (vectorizable && wgDim == 0 && !lossFactor) {
- candidates.push_back(4 * numThreads);
+ candidates.push_back(vectorizableCandidate);
}

// Try all power of two multiples of `workgroupTileMultiple` up to the
// subgroup size.
uint64_t maxCandidate =
@@ -645,17 +698,17 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
llvm::divideCeil(loopBound, scaledTileSize) <= 2) {
continue;
}

// Try to let each thread handle 4 elements if this is the workgroup x
// dimension.
- // TODO: Try to take into account element type bit width to get
- // 4xdword reads instead of 4x{elements}.
- if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
+ if (vectorizable && wgDim == 0 && !lossFactor &&
+     candidate % numVectorElements == 0) {
// Use size-1 vectors to increase parallelism if larger ones causes
// idle threads in the subgroup.
bool hasIdleThreads =
partitionableLoops.size() == 1 && candidate <= subgroupSize;
- int vectorSize = hasIdleThreads ? 1 : 4;
+ int vectorSize = hasIdleThreads ? 1 : numVectorElements;
LLVM_DEBUG(llvm::dbgs() << "Use vector size: " << vectorSize << "\n");
threadTileSizes[shapeDim] = vectorSize * scaleToByte;
candidateWorkgroupSize = candidate / vectorSize;
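For reference, below is a minimal standalone sketch of the vector-width heuristic introduced above: each thread targets roughly 128 bits per access (4 x f32, 8 x f16, 16 x 8-bit elements) and halves back toward 4 elements when the loop bound is too small to keep every thread busy. The helper name, simplified signature, and main() driver are illustrative assumptions rather than part of the patch, and the `vectorizable` gating from the real code is omitted for brevity.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Each thread targets roughly 128 bits worth of elements (one 4 x 32-bit
// load), but never fewer than 4 elements.
static unsigned pickNumVectorElements(unsigned representativeBitWidth,
                                      int64_t numThreads, int64_t loopBound) {
  unsigned numVectorElements = std::max(4u, 128 / representativeBitWidth);
  int64_t candidate = numVectorElements * numThreads;
  // For smaller shapes, halve the element count so every thread keeps work,
  // mirroring the while-loop in the patch.
  while (candidate > loopBound && numVectorElements > 4) {
    numVectorElements /= 2;
    candidate = numVectorElements * numThreads;
  }
  return numVectorElements;
}

int main() {
  // With 64 threads and a large loop bound the element count follows the
  // representative bit width: f32 -> 4, f16 -> 8, fp8/i8 -> 16. These match
  // the updated tests below (thread = [1, 1, 8] for f16, vector<16xf32> for fp8).
  assert(pickNumVectorElements(/*representativeBitWidth=*/32, 64, 1 << 20) == 4);
  assert(pickNumVectorElements(/*representativeBitWidth=*/16, 64, 1 << 20) == 8);
  assert(pickNumVectorElements(/*representativeBitWidth=*/8, 64, 1 << 20) == 16);
  // A loop bound of only 256 elements halves the f16 count back to 4.
  assert(pickNumVectorElements(/*representativeBitWidth=*/16, 64, 256) == 4);
  return 0;
}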
@@ -654,3 +654,25 @@ func.func @pack_dynamic_tile(%arg0: tensor<32x32xi8>, %d0: index, %d1: index, %t
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1, 4]
// CHECK-SAME: workgroup = [8, 32]

// -----

module {
func.func @erf(%13 : tensor<2x1024x5120xf16>, %12 : tensor<2x1024x5120xf16>, %9 : tensor<5120xf16>, %10 : tensor<f32>) -> tensor<2x1024x5120xi8> {
%cst = arith.constant 0.000000e+00 : f16
%11 = tensor.empty() : tensor<2x1024x5120xi8>
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13, %12, %9, %10 : tensor<2x1024x5120xf16>, tensor<2x1024x5120xf16>, tensor<5120xf16>, tensor<f32>) outs(%11 : tensor<2x1024x5120xi8>) {
^bb0(%in: f16, %in_4: f16, %in_5: f16, %in_6: f32, %out: i8):
%17 = math.erf %in : f16
%30 = arith.fptosi %17 : f16 to i8
linalg.yield %30 : i8
} -> tensor<2x1024x5120xi8>
return %14 : tensor<2x1024x5120xi8>
}
}

// CHECK-LABEL: func.func @erf
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1, 1, 8]
// CHECK-SAME: workgroup = [1, 1, 512]
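In this test the two tensor<2x1024x5120xf16> inputs are the only operands with permutation (identity) indexing maps, so getRepresentativeBitWidth returns 16 rather than the 8 bits of the i8 output; each thread then handles 128 / 16 = 8 elements, giving thread = [1, 1, 8] and workgroup = [1, 1, 512] (64 threads x 8 elements).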
@@ -137,10 +137,10 @@ hal.executable @ext_fp8_dispatch {

// CDNA3-LABEL: hal.executable public @ext_fp8_dispatch
// CDNA3: hal.executable.variant public @rocm
- // CDNA3-COUNT-4: rocdl.cvt.f32.fp8 %{{.*}} : f32
- // CDNA3-COUNT-4: rocdl.cvt.f32.bf8 %{{.*}} : f32
- // CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<4xf32>
- // CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<4xf32>, !llvm.ptr<1>
+ // CDNA3-COUNT-16: rocdl.cvt.f32.fp8 %{{.*}} : f32
+ // CDNA3-COUNT-16: rocdl.cvt.f32.bf8 %{{.*}} : f32
+ // CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<16xf32>
+ // CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<16xf32>, !llvm.ptr<1>
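With 8-bit fp8 inputs the same heuristic yields 128 / 8 = 16 elements per thread, which is why the expected conversion counts and the fadd/store vector width move from 4 to 16.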

// -----
