[LLVMGPU] Correct the workgroup level tile sizes for WarpReduction (#19819)

In the case of a matvec with no M dim, the workgroup-level tile size was set
to 1, which is suboptimal. The `isMatvecLike` function now also covers
cases with no M dim.
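For illustration (an editor's note based on the new @matvec_like_no_m_dim test below): such an op computes out[n] = sum_k vec[k] * mat[n][k] over a 32000x4096 matrix, i.e., one parallel dim n and one reduction dim k, with no M dim at all. Previously `isMatvecLike` required exactly two parallel loops, so this op missed the ROCm matvec special case and its sole parallel dim stayed tiled to 1; now it qualifies and receives the power-of-two tile size computed in setWarpReductionConfig (8 for these shapes, per the CHECK lines below).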
pashu123 authored Jan 28, 2025
1 parent 4b0ca34 commit 103d631
Showing 2 changed files with 102 additions and 24 deletions.
50 changes: 26 additions & 24 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1770,12 +1770,24 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target,
                                 preferredSubgroupSize);
 }
 
+/// Returns true if the op is matvec-like, i.e., the bound of the M or N dim
+/// is 1, or one of the M and N dims is absent.
 static bool isMatvecLike(linalg::LinalgOp linalgOp) {
-  if (linalgOp.getNumParallelLoops() != 2)
-    return false;
-
-  if (linalgOp.getNumReductionLoops() != 1)
+  SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
+  SmallVector<unsigned> parallelDims;
+  linalgOp.getParallelDims(parallelDims);
+
+  // Count the parallel dimensions whose static size is not 1.
+  unsigned nonUnitParallelDimsCount = llvm::count_if(
+      parallelDims, [&bounds](unsigned idx) { return bounds[idx] != 1; });
+
+  // There should be at most two parallel dims, exactly one non-unit parallel
+  // dim, and exactly one reduction loop.
+  if (parallelDims.size() > 2 || nonUnitParallelDimsCount != 1 ||
+      linalgOp.getNumReductionLoops() != 1) {
     return false;
+  }
 
   // TODO: Allow for matvec with fused dequantization.
   FailureOr<linalg::ContractionDimensions> dims =
@@ -1787,16 +1799,10 @@ static bool isMatvecLike(linalg::LinalgOp linalgOp) {
   if (!dims->batch.empty())
     return false;
 
-  for (ArrayRef<unsigned> indices : {dims->m, dims->n, dims->k}) {
-    if (!llvm::hasSingleElement(indices))
-      return false;
-  }
-
-  // Check if the first parallel dimension has bound 1, indicating we found a
-  // vector shape.
-  SmallVector<int64_t, 4> bounds = linalgOp.getStaticLoopRanges();
-  if (bounds[dims->m.front()] != 1)
+  if (dims->m.size() >= 2 || dims->n.size() >= 2 ||
+      !llvm::hasSingleElement(dims->k)) {
     return false;
+  }
 
   return true;
 }
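To make the new gating concrete, the following is a standalone, editor-written C++ sketch of the pre-check added above — it is not IREE code; plain std types stand in for SmallVector and llvm::count_if, and the name matvecLikePreCheck is ours — traced on the loop bounds of the two new tests:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the new pre-check in isMatvecLike (illustrative, not IREE code).
static bool matvecLikePreCheck(const std::vector<int64_t> &bounds,
                               const std::vector<unsigned> &parallelDims,
                               unsigned numReductionLoops) {
  // Count the parallel dims whose static bound is not 1.
  auto nonUnitParallelDimsCount =
      std::count_if(parallelDims.begin(), parallelDims.end(),
                    [&](unsigned idx) { return bounds[idx] != 1; });
  // At most two parallel dims, exactly one of them non-unit, and exactly one
  // reduction loop.
  return parallelDims.size() <= 2 && nonUnitParallelDimsCount == 1 &&
         numReductionLoops == 1;
}

int main() {
  // @matvec_like_no_m_dim below: loops (d0, d1) with bounds [32000, 4096],
  // d0 parallel, d1 reduction.
  std::cout << matvecLikePreCheck({32000, 4096}, {0}, 1) << "\n"; // 1
  // @matvec_unit_n_dim below: bounds [32000, 1, 4096], d0 and d1 parallel,
  // d2 reduction.
  std::cout << matvecLikePreCheck({32000, 1, 4096}, {0, 1}, 1) << "\n"; // 1
  return 0;
}

Both calls return true, while the old getNumParallelLoops() != 2 check would have rejected the first case outright.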
@@ -1868,12 +1874,7 @@ setWarpReductionConfig(IREE::GPU::TargetAttr target,
   if (!foundSingleReductionOutput)
     return failure();
 
-  // Tile all the parallel dimension to 1.
-  SmallVector<unsigned> partitionedLoops =
-      cast<PartitionableLoopsInterface>(op.getOperation())
-          .getPartitionableLoops(kNumMaxParallelDims);
-  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
-  SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
+  SmallVector<int64_t> workgroupTileSizes(op.getNumParallelLoops(), 1);
 
   // Without any bounds on dynamic dims, we need specialization to
   // get peak performance. For now, just use the warp size.
@@ -1978,17 +1979,18 @@ setWarpReductionConfig(IREE::GPU::TargetAttr target,
   // validate this strategy and extend to more linalg generics and to CUDA.
   if (isROCmBackend(target) && llvm::none_of(bounds, ShapedType::isDynamic) &&
       isMatvecLike(op)) {
-    int64_t lastParallelBound = bounds[parallelDims.back()];
+    int64_t parallelIdx = *llvm::find_if(
+        parallelDims, [&](int64_t currIdx) { return bounds[currIdx] != 1; });
+    int64_t parallelBound = bounds[parallelIdx];
     int64_t numParallelReductions = 1;
     const int64_t maxParallelFactor = groupSize / 4;
-    for (int64_t parallelFactor = 2;
-         (parallelFactor < maxParallelFactor) &&
-         (lastParallelBound % parallelFactor == 0) &&
-         (lastParallelBound > parallelFactor);
+    for (int64_t parallelFactor = 2; (parallelFactor < maxParallelFactor) &&
+                                     (parallelBound % parallelFactor == 0) &&
+                                     (parallelBound > parallelFactor);
          parallelFactor *= 2) {
       numParallelReductions = parallelFactor;
     }
-    workgroupTileSizes.back() = numParallelReductions;
+    workgroupTileSizes[parallelIdx] = numParallelReductions;
   }
 
   std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
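For a worked example of the doubling search above, here is a standalone, editor-written C++ sketch — not IREE code — evaluated with the values implied by the tests below: groupSize = 64 (from workgroup_size = [64, 1, 1] in the CHECK lines) and a non-unit parallel bound of 32000; both values are assumptions taken from the test file.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t groupSize = 64;       // assumed from workgroup_size [64, 1, 1]
  const int64_t parallelBound = 32000; // the non-unit parallel dim in the tests
  const int64_t maxParallelFactor = groupSize / 4; // 16
  int64_t numParallelReductions = 1;
  for (int64_t parallelFactor = 2; (parallelFactor < maxParallelFactor) &&
                                   (parallelBound % parallelFactor == 0) &&
                                   (parallelBound > parallelFactor);
       parallelFactor *= 2) {
    numParallelReductions = parallelFactor;
  }
  // Factors 2, 4, and 8 pass; 16 fails parallelFactor < maxParallelFactor, so
  // the search settles on 8 -- the parallel tile size the CHECK lines below
  // expect ([8] and [8, 1]).
  std::cout << numParallelReductions << "\n"; // prints 8
  return 0;
}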
76 changes: 76 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
@@ -82,6 +82,82 @@ func.func @vmt1() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
 
 // -----
 
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
+#map = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+func.func @matvec_like_no_m_dim() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096xf16>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<32000xf16>>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4096], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4096xf16>> -> tensor<4096xf16>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
+  %5 = tensor.empty() : tensor<32000xf16>
+  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<32000xf16>) -> tensor<32000xf16>
+  %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%3, %4 : tensor<4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<32000xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %8 = arith.mulf %in, %in_0 : f16
+    %9 = arith.addf %out, %8 : f16
+    linalg.yield %9 : f16
+  } -> tensor<32000xf16>
+  flow.dispatch.tensor.store %7, %2, offsets = [0], sizes = [32000], strides = [1] : tensor<32000xf16> -> !flow.dispatch.tensor<writeonly:tensor<32000xf16>>
+  return
+}
+
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8], [0, 512]{{\]}}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @matvec_like_no_m_dim()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[$CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matvec_unit_n_dim() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<32000x1xf16>>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>> -> tensor<1x4096xf16>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
+  %5 = tensor.empty() : tensor<32000x1xf16>
+  %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<32000x1xf16>) -> tensor<32000x1xf16>
+  %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%4, %3 : tensor<32000x4096xf16>, tensor<1x4096xf16>) outs(%6 : tensor<32000x1xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %8 = arith.mulf %in, %in_0 : f16
+    %9 = arith.addf %out, %8 : f16
+    linalg.yield %9 : f16
+  } -> tensor<32000x1xf16>
+  flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [32000, 1], strides = [1, 1] : tensor<32000x1xf16> -> !flow.dispatch.tensor<writeonly:tensor<32000x1xf16>>
+  return
+}
+
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 1], [0, 0, 512]{{\]}}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @matvec_unit_n_dim()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[$CONFIG]]
+
+// -----
+
 // This test uses special heuristics that need to check the backend in the #hal.executable.target.
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
