[CombineStridedOps] Generalize dimension checking #1032

Merged: 1 commit (Jan 16, 2025)
build_tools/ci/cpu_comparison/run.py: 41 changes (34 additions & 7 deletions)
@@ -708,7 +708,17 @@ class BatchMatmul(BaseMatmul):
     A test of the form batch_matmul(A,B) where A:BxMxK, B:BxKxN
     """
 
-    def __init__(self, B, M, N, K, input_type, acc_type, run_on_target=["npu1_4col"]):
+    def __init__(
+        self,
+        B,
+        M,
+        N,
+        K,
+        input_type,
+        acc_type,
+        run_on_target=["npu1_4col"],
+        tile_pipeline="pack-peel",
+    ):
         super().__init__(
             run_on_target=run_on_target,
             aie_compilation_flags=None,
@@ -717,12 +727,14 @@ def __init__(self, B, M, N, K, input_type, acc_type, run_on_target=["npu1_4col"]
             K=K,
             input_type=input_type,
             acc_type=acc_type,
-            tile_pipeline="pack-peel",
+            tile_pipeline=tile_pipeline,
             n_repeats=1,
         )
         self.labels.append("BatchMatmul")
 
         self.name = f"batch_matmul_{B}_{M}_{N}_{K}_{input_type}_{acc_type}"
+        if tile_pipeline == "pack-peel-4-level-tiling":
+            self.name += "_4_level_tiling"
         self.B = B
 
     def _execute(self, config):
@@ -1624,11 +1636,26 @@ def __init__(self):
         )
 
         # BatchMatmul test(s):
-        for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
-            # Batch size = 1:
-            self.register(BatchMatmul(1, 128, 128, 256, input_type, acc_type))
-            # Batch size = 2:
-            self.register(BatchMatmul(2, 64, 64, 64, input_type, acc_type))
+        for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
+            for input_type, acc_type in zip(["i32", "bf16"], ["i32", "f32"]):
+                # Batch size = 1:
+                self.register(
+                    BatchMatmul(
+                        1,
+                        128,
+                        128,
+                        256,
+                        input_type,
+                        acc_type,
+                        tile_pipeline=tile_pipeline,
+                    )
+                )
+                # Batch size = 2:
+                self.register(
+                    BatchMatmul(
+                        2, 64, 64, 64, input_type, acc_type, tile_pipeline=tile_pipeline
+                    )
+                )
 
         # MatmulThinBias test(s):
         self.register(MatmulThinBias(1024, 1024, 512, "bf16", "f32", use_ukernel=True))
@@ -21,6 +21,8 @@
 
 #define DEBUG_TYPE "iree-amdaie-combine-strided-ops"
 
+using namespace std::placeholders;
+
 namespace mlir::iree_compiler::AMDAIE {
 
 namespace {
@@ -43,6 +45,8 @@ struct CombineStridedOps
     Block *block = op->getBlock();
     if (!block) return failure();
 
+    std::unique_ptr<DmaDimConfig> sourceDmaDimConfig;
+    std::unique_ptr<DmaDimConfig> targetDmaDimConfig;
     SmallVector<Operation *> userOpsToBeErased;
     AMDAIE::DoublyStridedOpInterface nextStridedOp;
     if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op.getOperation())) {
@@ -60,6 +64,20 @@
       if (failed(maybeNextNpuDmaOp)) return failure();
       nextStridedOp = cast<AMDAIE::DoublyStridedOpInterface>(
          maybeNextNpuDmaOp->getOperation());
+      if (!nextStridedOp) return failure();
+
+      std::optional<uint8_t> sourceMemspaceInt =
+          nextStridedOp.getSourceMemorySpaceAsUInt();
+      std::optional<uint8_t> targetMemspaceInt =
+          nextStridedOp.getTargetMemorySpaceAsUInt();
+      if (!sourceMemspaceInt || !targetMemspaceInt) {
+        return rewriter.notifyMatchFailure(
+            nextStridedOp, "expected a source and target memory space");
+      }
+      sourceDmaDimConfig = std::make_unique<DmaDimConfig>(
+          deviceModel, sourceMemspaceInt.value());
+      targetDmaDimConfig = std::make_unique<DmaDimConfig>(
+          deviceModel, targetMemspaceInt.value());
     } else if (auto npuCircularDmaOp =
                    dyn_cast<AMDAIE::NpuCircularDmaCpyNdOp>(op.getOperation())) {
       LLVM_DEBUG(llvm::dbgs()
@@ -69,25 +87,24 @@
       if (failed(maybeNextNpuCircDmaOp)) return failure();
       nextStridedOp = cast<AMDAIE::DoublyStridedOpInterface>(
          maybeNextNpuCircDmaOp->getOperation());
+      if (!nextStridedOp) return failure();
+
+      std::optional<uint8_t> sourceMemspaceInt =
+          nextStridedOp.getSourceMemorySpaceAsUInt();
+      std::optional<uint8_t> targetMemspaceInt =
+          nextStridedOp.getTargetMemorySpaceAsUInt();
+      if (!sourceMemspaceInt || !targetMemspaceInt) {
+        return rewriter.notifyMatchFailure(
+            nextStridedOp, "expected a source and target memory space");
+      }
+      sourceDmaDimConfig = std::make_unique<CircularDmaDimConfig>(
+          deviceModel, sourceMemspaceInt.value());
+      targetDmaDimConfig = std::make_unique<CircularDmaDimConfig>(
+          deviceModel, targetMemspaceInt.value());
     } else {
       return failure();
     }
 
-    if (!nextStridedOp) return failure();
-
-    std::optional<uint8_t> sourceMemspaceInt =
-        nextStridedOp.getSourceMemorySpaceAsUInt();
-    std::optional<uint8_t> targetMemspaceInt =
-        nextStridedOp.getTargetMemorySpaceAsUInt();
-    if (!sourceMemspaceInt || !targetMemspaceInt) {
-      return rewriter.notifyMatchFailure(
-          nextStridedOp, "expected a source and target memory space");
-    }
-    DmaDimConfig sourceDmaDimConfig(deviceModel, sourceMemspaceInt.value());
-    size_t sourceMaxNbDims = sourceDmaDimConfig.maxNbDims;
-    DmaDimConfig targetDmaDimConfig(deviceModel, targetMemspaceInt.value());
-    size_t targetMaxNbDims = targetDmaDimConfig.maxNbDims;
-
     SmallVector<OpFoldResult> sourceOffsetsA = op.getSourceMixedOffsets();
     SmallVector<OpFoldResult> sourceSizesA = op.getSourceMixedSizes();
     SmallVector<OpFoldResult> sourceStridesA = op.getSourceMixedStrides();
@@ -99,7 +116,9 @@ struct CombineStridedOps
         nextStridedOp.getSourceMixedStrides();
     bool areSourcesCombinable = areAccessPatternsCombinable(
         sourceOffsetsA, sourceSizesA, sourceStridesA, sourceOffsetsB,
-        sourceSizesB, sourceStridesB, sourceMaxNbDims);
+        sourceSizesB, sourceStridesB,
+        std::bind(&DmaDimConfig::exceedsNbDims, std::ref(sourceDmaDimConfig),
+                  _1));
 
     SmallVector<OpFoldResult> targetOffsetsA = op.getTargetMixedOffsets();
     SmallVector<OpFoldResult> targetSizesA = op.getTargetMixedSizes();
@@ -112,7 +131,14 @@
         nextStridedOp.getTargetMixedStrides();
     bool areTargetsCombinable = areAccessPatternsCombinable(
         targetOffsetsA, targetSizesA, targetStridesA, targetOffsetsB,
-        targetSizesB, targetStridesB, targetMaxNbDims);
+        targetSizesB, targetStridesB,
+        std::bind(&DmaDimConfig::exceedsNbDims, std::ref(targetDmaDimConfig),
+                  _1));
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "areSourcesCombinable: " << areSourcesCombinable << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "areTargetsCombinable: " << areTargetsCombinable << "\n");
 
     if (areSourcesCombinable && areTargetsCombinable) {
       SmallVector<OpFoldResult> newSourceOffsets;
@@ -121,7 +147,9 @@
       if (failed(combineAccessPatterns(
               rewriter, sourceOffsetsA, sourceSizesA, sourceStridesA,
               sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets,
-              newSourceSizes, newSourceStrides, sourceMaxNbDims))) {
+              newSourceSizes, newSourceStrides,
+              std::bind(&DmaDimConfig::exceedsNbDims,
+                        std::ref(sourceDmaDimConfig), _1)))) {
         return failure();
       }
 
@@ -131,7 +159,9 @@
       if (failed(combineAccessPatterns(
               rewriter, targetOffsetsA, targetSizesA, targetStridesA,
               targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets,
-              newTargetSizes, newTargetStrides, targetMaxNbDims))) {
+              newTargetSizes, newTargetStrides,
+              std::bind(&DmaDimConfig::exceedsNbDims,
+                        std::ref(targetDmaDimConfig), _1)))) {
         return failure();
       }
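
Note on the mechanism above: rather than threading raw `sourceMaxNbDims`/`targetMaxNbDims` values into the DMA utilities, the pattern now owns a `DmaDimConfig` (or `CircularDmaDimConfig`) behind a base-class pointer and passes the dimension check itself as a callable built with `std::bind`. The following is a minimal, self-contained sketch of that idea, using illustrative stand-in names rather than the AMDAIE classes (the real `exceedsNbDims` depends on the device model and memory space):

// Sketch only: stand-ins for DmaDimConfig/CircularDmaDimConfig and for the
// predicate-taking combine check. All names and bounds here are hypothetical.
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>

using namespace std::placeholders;

struct DimConfig {
  virtual ~DimConfig() = default;
  // True if an access pattern with `nbDims` dimensions is unsupported
  // (hypothetical bound; the real check consults the device model).
  virtual bool exceedsNbDims(size_t nbDims) const { return nbDims > 4; }
};

struct CircularDimConfig : DimConfig {
  // Hypothetical: circular DMAs tolerate a different number of dimensions.
  bool exceedsNbDims(size_t nbDims) const override { return nbDims > 3; }
};

// Stand-in for areAccessPatternsCombinable: same-rank patterns need one
// extra dimension after combining, so the predicate is asked about rank + 1.
bool canCombine(size_t nbDimsA, size_t nbDimsB,
                std::function<bool(size_t)> exceedsNbDims) {
  if (nbDimsA == nbDimsB && exceedsNbDims(nbDimsA + 1)) return false;
  return true;
}

int main() {
  std::unique_ptr<DimConfig> config = std::make_unique<CircularDimConfig>();
  // Same binding style as the PR: a member function bound to the config.
  auto pred = std::bind(&DimConfig::exceedsNbDims, std::ref(*config), _1);
  std::cout << canCombine(3, 3, pred) << "\n";  // 0: 4 dims exceed the bound
  std::cout << canCombine(2, 2, pred) << "\n";  // 1: 3 dims still fit
}

With the check packaged as a predicate, `areAccessPatternsCombinable` no longer needs to know which config subclass produced the bound; the virtual call inside the callable performs the dispatch.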

@@ -11,6 +11,8 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 
+#define DEBUG_TYPE "iree-amdaie-dma-utils"
+
 namespace mlir::iree_compiler::AMDAIE {
 
 static bool isEqualConstantIntOrValueArrayFromIndices(
@@ -42,7 +44,7 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
                                  const SmallVector<OpFoldResult> &offsetsB,
                                  const SmallVector<OpFoldResult> &sizesB,
                                  const SmallVector<OpFoldResult> &stridesB,
-                                 size_t maxNbDims) {
+                                 function_ref<bool(size_t)> exceedsNbDims) {
   assert(offsetsA.size() == sizesA.size() &&
          "expected same number of source offsets and sizes");
   assert(offsetsA.size() == stridesA.size() &&
@@ -59,8 +61,11 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
   // In case both access patterns have the same number of dimensions, a new
   // dimension will need to be added, so fail if there aren't enough
   // dimensions.
-  if (offsetsA.size() == offsetsB.size() && offsetsA.size() + 1 > maxNbDims)
+  if (offsetsA.size() == offsetsB.size() &&
+      exceedsNbDims(offsetsA.size() + 1)) {
+    LLVM_DEBUG(llvm::dbgs() << "Exceeded maximum number of dimensions\n");
     return false;
+  }
 
   // Equality of the last N elements of the access patterns of A and B with N =
   // min(sizeA, sizeB) results in some simple cases in which the access
@@ -196,7 +201,7 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
                                     SmallVector<OpFoldResult> &newOffsets,
                                     SmallVector<OpFoldResult> &newSizes,
                                     SmallVector<OpFoldResult> &newStrides,
-                                    size_t maxNbDims) {
+                                    function_ref<bool(size_t)> exceedsNbDims) {
   assert(offsetsA.size() == sizesA.size() &&
          "expected same number of source offsets and sizes");
   assert(offsetsA.size() == stridesA.size() &&
@@ -206,7 +211,7 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
   assert(offsetsB.size() == stridesB.size() &&
          "expected same number of source offsets and strides");
   if (!areAccessPatternsCombinable(offsetsA, sizesA, stridesA, offsetsB, sizesB,
-                                   stridesB, maxNbDims)) {
+                                   stridesB, exceedsNbDims)) {
     return failure();
   }
   if (offsetsA.empty() && offsetsB.empty()) return success();
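
For intuition on the `exceedsNbDims(offsetsA.size() + 1)` check above: when the two access patterns have the same rank, the combined pattern needs one extra outermost dimension. The sketch below uses hypothetical values and assumes the new dimension's stride is the linear offset between the two accesses; the real `combineAccessPatterns` operates on `OpFoldResult`s and handles further cases:

// Illustrative only: enumerate the elements touched by a strided access
// pattern and check that two same-rank patterns, the second being the first
// shifted by a constant offset, match one pattern with an extra outer dim.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// All linear element indices touched by (offsets, sizes, strides),
// innermost dimension last.
std::set<int64_t> touched(const std::vector<int64_t> &offsets,
                          const std::vector<int64_t> &sizes,
                          const std::vector<int64_t> &strides) {
  std::set<int64_t> out;
  std::vector<int64_t> idx(sizes.size(), 0);
  for (;;) {
    int64_t lin = 0;
    for (size_t d = 0; d < sizes.size(); ++d)
      lin += (offsets[d] + idx[d]) * strides[d];
    out.insert(lin);
    // Odometer-style increment over the index space.
    size_t d = sizes.size();
    for (;;) {
      if (d == 0) return out;
      --d;
      if (++idx[d] < sizes[d]) break;
      idx[d] = 0;
    }
  }
}

int main() {
  // Pattern A: a 2x8 tile at the origin of a row-major buffer (row stride 16).
  std::set<int64_t> a = touched({0, 0}, {2, 8}, {16, 1});
  // Pattern B: the same tile, two rows (32 elements) further on.
  std::set<int64_t> b = touched({2, 0}, {2, 8}, {16, 1});
  // Combined: both ranks were equal, so a new outer dimension of size 2 is
  // added; its stride (32) is the linear offset between A and B.
  std::set<int64_t> c = touched({0, 0, 0}, {2, 2, 8}, {32, 16, 1});
  std::set<int64_t> ab = a;
  ab.insert(b.begin(), b.end());
  std::cout << (ab == c ? "combined pattern covers A then B" : "mismatch")
            << "\n";
}

The union of elements touched by A and B equals the elements touched by the combined 3-D pattern, which is why a same-rank merge must be rejected when the hardware cannot support one additional dimension.
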
@@ -93,7 +93,7 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
                                  const SmallVector<OpFoldResult> &offsetsB,
                                  const SmallVector<OpFoldResult> &sizesB,
                                  const SmallVector<OpFoldResult> &stridesB,
-                                 size_t maxNbDims);
+                                 function_ref<bool(size_t)> exceedsNbDims);
 
 /// Combine two access patterns into a single one. Assumes that access pattern A
 /// belongs to a strided op which is ordered before the strided op B. Takes a
@@ -110,7 +110,7 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
                                     SmallVector<OpFoldResult> &newOffsets,
                                     SmallVector<OpFoldResult> &newSizes,
                                     SmallVector<OpFoldResult> &newStrides,
-                                    size_t maxNbDims);
+                                    function_ref<bool(size_t)> exceedsNbDims);
 
 /// Fold subsequent dimensions within a strided access pattern that describe a
 /// single linear access. Returns `success` if folding took place.