Skip to content

Commit

Permalink
[CombineStridedOps] Generalize dimension checking
Browse files: browse the repository at this point in the history
  • Loading branch information
jtuyls committed Jan 15, 2025
1 parent 42fa1e9 commit 75060e4
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 61 deletions.
41 changes: 34 additions & 7 deletions build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,17 @@ class BatchMatmul(BaseMatmul):
A test of the form batch_matmul(A,B) where A:BxMxK, B:BxKxN
"""

def __init__(self, B, M, N, K, input_type, acc_type, run_on_target=["npu1_4col"]):
def __init__(
self,
B,
M,
N,
K,
input_type,
acc_type,
run_on_target=["npu1_4col"],
tile_pipeline="pack-peel",
):
super().__init__(
run_on_target=run_on_target,
aie_compilation_flags=None,
Expand All @@ -717,12 +727,14 @@ def __init__(self, B, M, N, K, input_type, acc_type, run_on_target=["npu1_4col"]
K=K,
input_type=input_type,
acc_type=acc_type,
tile_pipeline="pack-peel",
tile_pipeline=tile_pipeline,
n_repeats=1,
)
self.labels.append("BatchMatmul")

self.name = f"batch_matmul_{B}_{M}_{N}_{K}_{input_type}_{acc_type}"
if tile_pipeline == "pack-peel-4-level-tiling":
self.name += "_4_level_tiling"
self.B = B

def _execute(self, config):
Expand Down Expand Up @@ -1624,11 +1636,26 @@ def __init__(self):
)

# BatchMatmul test(s):
for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
# Batch size = 1:
self.register(BatchMatmul(1, 128, 128, 256, input_type, acc_type))
# Batch size = 2:
self.register(BatchMatmul(2, 64, 64, 64, input_type, acc_type))
for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
# Batch size = 1:
self.register(
BatchMatmul(
1,
128,
128,
256,
input_type,
acc_type,
tile_pipeline=tile_pipeline,
)
)
# Batch size = 2:
self.register(
BatchMatmul(
2, 64, 64, 64, input_type, acc_type, tile_pipeline=tile_pipeline
)
)

# MatmulThinBias test(s):
self.register(MatmulThinBias(1024, 1024, 512, "bf16", "f32", use_ukernel=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

#define DEBUG_TYPE "iree-amdaie-combine-strided-ops"

using namespace std::placeholders;

namespace mlir::iree_compiler::AMDAIE {

namespace {
Expand All @@ -43,6 +45,8 @@ struct CombineStridedOps
Block *block = op->getBlock();
if (!block) return failure();

std::unique_ptr<DmaDimConfig> sourceDmaDimConfig;
std::unique_ptr<DmaDimConfig> targetDmaDimConfig;
SmallVector<Operation *> userOpsToBeErased;
AMDAIE::DoublyStridedOpInterface nextStridedOp;
if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op.getOperation())) {
Expand All @@ -60,6 +64,20 @@ struct CombineStridedOps
if (failed(maybeNextNpuDmaOp)) return failure();
nextStridedOp = cast<AMDAIE::DoublyStridedOpInterface>(
maybeNextNpuDmaOp->getOperation());
if (!nextStridedOp) return failure();

std::optional<uint8_t> sourceMemspaceInt =
nextStridedOp.getSourceMemorySpaceAsUInt();
std::optional<uint8_t> targetMemspaceInt =
nextStridedOp.getTargetMemorySpaceAsUInt();
if (!sourceMemspaceInt || !targetMemspaceInt) {
return rewriter.notifyMatchFailure(
nextStridedOp, "expected a source and target memory space");
}
sourceDmaDimConfig = std::make_unique<DmaDimConfig>(
deviceModel, sourceMemspaceInt.value());
targetDmaDimConfig = std::make_unique<DmaDimConfig>(
deviceModel, targetMemspaceInt.value());
} else if (auto npuCircularDmaOp =
dyn_cast<AMDAIE::NpuCircularDmaCpyNdOp>(op.getOperation())) {
LLVM_DEBUG(llvm::dbgs()
Expand All @@ -69,25 +87,24 @@ struct CombineStridedOps
if (failed(maybeNextNpuCircDmaOp)) return failure();
nextStridedOp = cast<AMDAIE::DoublyStridedOpInterface>(
maybeNextNpuCircDmaOp->getOperation());
if (!nextStridedOp) return failure();

std::optional<uint8_t> sourceMemspaceInt =
nextStridedOp.getSourceMemorySpaceAsUInt();
std::optional<uint8_t> targetMemspaceInt =
nextStridedOp.getTargetMemorySpaceAsUInt();
if (!sourceMemspaceInt || !targetMemspaceInt) {
return rewriter.notifyMatchFailure(
nextStridedOp, "expected a source and target memory space");
}
sourceDmaDimConfig = std::make_unique<CircularDmaDimConfig>(
deviceModel, sourceMemspaceInt.value());
targetDmaDimConfig = std::make_unique<CircularDmaDimConfig>(
deviceModel, targetMemspaceInt.value());
} else {
return failure();
}

if (!nextStridedOp) return failure();

std::optional<uint8_t> sourceMemspaceInt =
nextStridedOp.getSourceMemorySpaceAsUInt();
std::optional<uint8_t> targetMemspaceInt =
nextStridedOp.getTargetMemorySpaceAsUInt();
if (!sourceMemspaceInt || !targetMemspaceInt) {
return rewriter.notifyMatchFailure(
nextStridedOp, "expected a source and target memory space");
}
DmaDimConfig sourceDmaDimConfig(deviceModel, sourceMemspaceInt.value());
size_t sourceMaxNbDims = sourceDmaDimConfig.maxNbDims;
DmaDimConfig targetDmaDimConfig(deviceModel, targetMemspaceInt.value());
size_t targetMaxNbDims = targetDmaDimConfig.maxNbDims;

SmallVector<OpFoldResult> sourceOffsetsA = op.getSourceMixedOffsets();
SmallVector<OpFoldResult> sourceSizesA = op.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStridesA = op.getSourceMixedStrides();
Expand All @@ -99,7 +116,9 @@ struct CombineStridedOps
nextStridedOp.getSourceMixedStrides();
bool areSourcesCombinable = areAccessPatternsCombinable(
sourceOffsetsA, sourceSizesA, sourceStridesA, sourceOffsetsB,
sourceSizesB, sourceStridesB, sourceMaxNbDims);
sourceSizesB, sourceStridesB,
std::bind(&DmaDimConfig::exceedsNbDims, std::ref(sourceDmaDimConfig),
_1));

SmallVector<OpFoldResult> targetOffsetsA = op.getTargetMixedOffsets();
SmallVector<OpFoldResult> targetSizesA = op.getTargetMixedSizes();
Expand All @@ -112,7 +131,14 @@ struct CombineStridedOps
nextStridedOp.getTargetMixedStrides();
bool areTargetsCombinable = areAccessPatternsCombinable(
targetOffsetsA, targetSizesA, targetStridesA, targetOffsetsB,
targetSizesB, targetStridesB, targetMaxNbDims);
targetSizesB, targetStridesB,
std::bind(&DmaDimConfig::exceedsNbDims, std::ref(targetDmaDimConfig),
_1));

LLVM_DEBUG(llvm::dbgs()
<< "areSourcesCombinable: " << areSourcesCombinable << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "areTargetsCombinable: " << areTargetsCombinable << "\n");

if (areSourcesCombinable && areTargetsCombinable) {
SmallVector<OpFoldResult> newSourceOffsets;
Expand All @@ -121,7 +147,9 @@ struct CombineStridedOps
if (failed(combineAccessPatterns(
rewriter, sourceOffsetsA, sourceSizesA, sourceStridesA,
sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets,
newSourceSizes, newSourceStrides, sourceMaxNbDims))) {
newSourceSizes, newSourceStrides,
std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(sourceDmaDimConfig), _1)))) {
return failure();
}

Expand All @@ -131,7 +159,9 @@ struct CombineStridedOps
if (failed(combineAccessPatterns(
rewriter, targetOffsetsA, targetSizesA, targetStridesA,
targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets,
newTargetSizes, newTargetStrides, targetMaxNbDims))) {
newTargetSizes, newTargetStrides,
std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(targetDmaDimConfig), _1)))) {
return failure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

#define DEBUG_TYPE "iree-amdaie-dma-utils"

namespace mlir::iree_compiler::AMDAIE {

static bool isEqualConstantIntOrValueArrayFromIndices(
Expand Down Expand Up @@ -42,7 +44,7 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
const SmallVector<OpFoldResult> &offsetsB,
const SmallVector<OpFoldResult> &sizesB,
const SmallVector<OpFoldResult> &stridesB,
size_t maxNbDims) {
function_ref<bool(size_t)> exceedsNbDims) {
assert(offsetsA.size() == sizesA.size() &&
"expected same number of source offsets and sizes");
assert(offsetsA.size() == stridesA.size() &&
Expand All @@ -59,8 +61,11 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
// In case both access patterns have the same number of dimensions, a new
// dimension will need to be added, so fail if there aren't enough
// dimensions.
if (offsetsA.size() == offsetsB.size() && offsetsA.size() + 1 > maxNbDims)
if (offsetsA.size() == offsetsB.size() &&
exceedsNbDims(offsetsA.size() + 1)) {
LLVM_DEBUG(llvm::dbgs() << "Exceeded maximum number of dimensions\n");
return false;
}

// Equality of the last N elements of the access patterns of A and B with N =
// min(sizeA, sizeB) results in some simple cases in which the access
Expand Down Expand Up @@ -196,7 +201,7 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
SmallVector<OpFoldResult> &newOffsets,
SmallVector<OpFoldResult> &newSizes,
SmallVector<OpFoldResult> &newStrides,
size_t maxNbDims) {
function_ref<bool(size_t)> exceedsNbDims) {
assert(offsetsA.size() == sizesA.size() &&
"expected same number of source offsets and sizes");
assert(offsetsA.size() == stridesA.size() &&
Expand All @@ -206,7 +211,7 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
assert(offsetsB.size() == stridesB.size() &&
"expected same number of source offsets and strides");
if (!areAccessPatternsCombinable(offsetsA, sizesA, stridesA, offsetsB, sizesB,
stridesB, maxNbDims)) {
stridesB, exceedsNbDims)) {
return failure();
}
if (offsetsA.empty() && offsetsB.empty()) return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
const SmallVector<OpFoldResult> &offsetsB,
const SmallVector<OpFoldResult> &sizesB,
const SmallVector<OpFoldResult> &stridesB,
size_t maxNbDims);
function_ref<bool(size_t)> exceedsNbDims);

/// Combine two access patterns into a single one. Assumes that access pattern A
/// belongs to a strided op which is ordered before the strided op B. Takes a
Expand All @@ -110,7 +110,7 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
SmallVector<OpFoldResult> &newOffsets,
SmallVector<OpFoldResult> &newSizes,
SmallVector<OpFoldResult> &newStrides,
size_t maxNbDims);
function_ref<bool(size_t)> exceedsNbDims);

/// Fold subsequent dimensions within a strided access pattern that describe a
/// single linear access. Returns `success` if folding took place.
Expand Down
Loading

0 comments on commit 75060e4

Please sign in to comment.