Skip to content

Commit

Permalink
[CombineStridedOps] Add a combinable case (nod-ai#839)
Browse files Browse the repository at this point in the history
This PR adds support for a corner case in strided op combination. With this PR, the
following strided ops:

```
%48 = amdaie.npu.dma_cpy_nd %8([] [] [], %31[0, %45] [32, 64] [128, 1]) : source_type = !amdaie.logicalobjectfifo<memref<16384xi32>>
%49 = amdaie.npu.dma_cpy_nd %8([] [] [], %31[32, %45] [96, 64] [128, 1]) : source_type = !amdaie.logicalobjectfifo<memref<16384xi32>>
```

can be combined into a single op:

`%48 = amdaie.npu.dma_cpy_nd %8([] [] [], %31[0, %45] [128, 64] [128,
1]) : source_type = !amdaie.logicalobjectfifo<memref<16384xi32>>
`

Addresses the review comments from
nod-ai#826 (comment).
  • Loading branch information
yzhang93 authored Oct 9, 2024
1 parent c84cca0 commit c607072
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
}
if (strideA != strideB) return false;
}

// Don't check the outermost dimension of size at this point.
SmallVector<OpFoldResult> innerSizesA;
SmallVector<OpFoldResult> innerSizesB;
std::copy(sizesA.begin() + 1, sizesA.end(), std::back_inserter(innerSizesA));
std::copy(sizesB.begin() + 1, sizesB.end(), std::back_inserter(innerSizesB));
for (auto &&[sizeA, sizeB] :
llvm::zip(llvm::reverse(sizesA), llvm::reverse(sizesB))) {
llvm::zip(llvm::reverse(innerSizesA), llvm::reverse(innerSizesB))) {
std::optional<int64_t> maybeSizeA = getConstantIntValue(sizeA);
std::optional<int64_t> maybeSizeB = getConstantIntValue(sizeB);
// Handle static and constant value with same int value.
Expand All @@ -71,6 +77,20 @@ bool areAccessPatternsCombinable(const SmallVector<OpFoldResult> &offsetsA,
if (sizeA != sizeB) return false;
}

// Edge case for sizesA[0] != sizesB[0].
if (offsetsB.size() == offsetsA.size() && sizesA[0] != sizesB[0]) {
std::optional<int64_t> constOffsetA = getConstantIntValue(offsetsA[0]);
std::optional<int64_t> constSizeA = getConstantIntValue(sizesA[0]);
std::optional<int64_t> constOffsetB = getConstantIntValue(offsetsB[0]);
std::optional<int64_t> constSizeB = getConstantIntValue(sizesB[0]);
if (constOffsetA && constOffsetB && constSizeA && constSizeB) {
int64_t offsetDiff = constOffsetB.value() - constOffsetA.value();
if (constSizeA.value() != offsetDiff) return false;
} else {
return false;
}
}

bool foundDiff{false};
for (auto iter : llvm::enumerate(
llvm::zip(llvm::reverse(offsetsA), llvm::reverse(offsetsB)))) {
Expand Down Expand Up @@ -169,40 +189,50 @@ LogicalResult combineAccessPatterns(RewriterBase &rewriter,
if (!size) return failure();
newSizes[0] = rewriter.getI64IntegerAttr(size.value() + 1);
} else {
// Sizes are the same, so add a new dimension with 'offset == 0', 'size ==
// 2' and 'stride == offsetDiff'.
newOffsets.push_back(rewriter.getI64IntegerAttr(0));
int64_t offsetDiff;
int64_t strideMultiplier;
for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) {
const OpFoldResult &offsetA = std::get<0>(iter.value());
const OpFoldResult &offsetB = std::get<1>(iter.value());
newOffsets.push_back(offsetA);
if (offsetA != offsetB) {
std::optional<int64_t> constOffsetA = getConstantIntValue(offsetA);
std::optional<int64_t> constOffsetB = getConstantIntValue(offsetB);
if (!constOffsetA || !constOffsetB) {
return emitError(rewriter.getUnknownLoc())
<< "differing offsets should be constants";
}
offsetDiff = constOffsetB.value() - constOffsetA.value();
std::optional<int64_t> maybeStride =
getConstantIntValue(stridesA[iter.index()]);
if (!maybeStride) {
return emitError(rewriter.getUnknownLoc())
<< "no constant stride found at the same index where the "
"offset "
"difference occurs";
// Edge case for sizesA[0] != sizesB[0].
if (sizesA[0] != sizesB[0]) {
newOffsets = offsetsA;
newSizes = sizesA;
newStrides = stridesA;
std::optional<int64_t> sizeA = getConstantIntValue(sizesA[0]);
std::optional<int64_t> sizeB = getConstantIntValue(sizesB[0]);
if (!sizeA || !sizeB) return failure();
newSizes[0] = rewriter.getI64IntegerAttr(sizeA.value() + sizeB.value());
} else {
// All dims of sizes are the same, so add a new dimension with
// 'offset == 0', 'size == 2' and 'stride == offsetDiff'.
newOffsets.push_back(rewriter.getI64IntegerAttr(0));
int64_t offsetDiff;
int64_t strideMultiplier;
for (auto iter : llvm::enumerate(llvm::zip(offsetsA, offsetsB))) {
const OpFoldResult &offsetA = std::get<0>(iter.value());
const OpFoldResult &offsetB = std::get<1>(iter.value());
newOffsets.push_back(offsetA);
if (offsetA != offsetB) {
std::optional<int64_t> constOffsetA = getConstantIntValue(offsetA);
std::optional<int64_t> constOffsetB = getConstantIntValue(offsetB);
if (!constOffsetA || !constOffsetB) {
return emitError(rewriter.getUnknownLoc())
<< "differing offsets should be constants";
}
offsetDiff = constOffsetB.value() - constOffsetA.value();
std::optional<int64_t> maybeStride =
getConstantIntValue(stridesA[iter.index()]);
if (!maybeStride) {
return emitError(rewriter.getUnknownLoc())
<< "no constant stride found at the same index where the "
"offset "
"difference occurs";
}
strideMultiplier = maybeStride.value();
}
strideMultiplier = maybeStride.value();
}
newSizes.push_back(rewriter.getI64IntegerAttr(2));
newSizes.append(sizesA.begin(), sizesA.end());
newStrides.push_back(
rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier));
newStrides.append(stridesA.begin(), stridesA.end());
}
newSizes.push_back(rewriter.getI64IntegerAttr(2));
newSizes.append(sizesA.begin(), sizesA.end());
newStrides.push_back(
rewriter.getI64IntegerAttr(offsetDiff * strideMultiplier));
newStrides.append(stridesA.begin(), stridesA.end());
;
}
assert(newOffsets.size() == newSizes.size() &&
"expected same number of new offsets and sizes");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ LogicalResult bufferizeTemporaryMemrefs(Operation *parentOp) {
});
}


// Note: we don't erase allocs/deallocs, we leave this for canonicalization.

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ TEST_F(AccessPatternCombinationTest, CombinableAccessPatterns) {
EXPECT_TRUE(checkAreAccessPatternsCombinable({0, 2, 0}, {16, 16, 32},
{32, 64, 1}, {0, 2, 32},
{16, 16, 32}, {32, 64, 1}, 4));
EXPECT_TRUE(checkAreAccessPatternsCombinable({32, 0}, {64, 64}, {128, 1},
{96, 0}, {32, 64}, {128, 1}, 4));
// size(A) > size(B)
EXPECT_TRUE(checkAreAccessPatternsCombinable(
{0, 0, 0}, {2, 16, 32}, {32, 64, 1}, {0, 64}, {16, 32}, {64, 1}, 4));
Expand Down Expand Up @@ -168,6 +170,12 @@ TEST_F(AccessPatternCombinationTest, NonCombinableAccessPatterns) {
{0, 0}, {16, 32}, {64, 1}, {0, 0, 96}, {2, 16, 32}, {32, 64, 1}, 4));
EXPECT_FALSE(checkAreAccessPatternsCombinable(
{0, 0}, {16, 32}, {64, 1}, {0, 1, 0}, {2, 16, 32}, {32, 64, 1}, 4));

// size(A) == size(B) Incompatible offset
EXPECT_FALSE(checkAreAccessPatternsCombinable(
{32, 0}, {64, 64}, {128, 1}, {32, 0}, {32, 64}, {128, 1}, 4));
EXPECT_FALSE(checkAreAccessPatternsCombinable(
{32, 0}, {32, 64}, {128, 1}, {96, 0}, {64, 64}, {128, 1}, 4));
}

TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) {
Expand Down Expand Up @@ -197,6 +205,8 @@ TEST_F(AccessPatternCombinationTest, CombineAccessPatterns) {
checkCombineAccessPatterns({8, 0, 0}, {16, 8, 16}, {16, 8, 1}, {40, 0, 0},
{16, 8, 16}, {16, 8, 1}, {0, 8, 0, 0},
{2, 16, 8, 16}, {512, 16, 8, 1}, 4);
checkCombineAccessPatterns({32, 0}, {64, 64}, {128, 1}, {96, 0}, {32, 64},
{128, 1}, {32, 0}, {96, 64}, {128, 1}, 4);
// size(A) > size(B)
checkCombineAccessPatterns({0, 0}, {2, 32}, {64, 1}, {128}, {32}, {1}, {0, 0},
{3, 32}, {64, 1}, 3);
Expand Down Expand Up @@ -255,6 +265,10 @@ TEST_F(AccessPatternCombinationTest, FailCombineAccessPatterns) {
{3, 32}, {64, 1}, 3, false);
checkCombineAccessPatterns({0}, {32}, {1}, {0, 96}, {2, 32}, {64, 1}, {0, 0},
{3, 32}, {64, 1}, 3, false);

// size(A) == size(B) Incompatible offset
checkCombineAccessPatterns({32, 0}, {32, 64}, {128, 1}, {96, 0}, {64, 64},
{128, 1}, {32, 0}, {96, 64}, {128, 1}, 4, false);
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,28 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}

// -----

// Source access patterns with the same number of dimensions and identical
// strides, but different outermost sizes (32, 64, 32). Each op's outermost
// offset equals the previous op's outermost offset plus its outermost size
// (0, 32, 96), so the three ops are expected to be merged into a single op
// with outermost offset 0 and outermost size 32 + 64 + 32 = 128.
// CHECK-LABEL: @combine_source_same_dims_diff_sizes
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([] [] [], [0, 0] [128, 64] [128, 1])
// CHECK-NOT: amdaie.npu.dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @combine_source_same_dims_diff_sizes(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<128x128xi32>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>)
amdaie.controlcode {
amdaie.npu.dma_cpy_nd %0([] [] [], [0, 0] [32, 64] [128, 1])
amdaie.npu.dma_cpy_nd %0([] [] [], [32, 0] [64, 64] [128, 1])
amdaie.npu.dma_cpy_nd %0([] [] [], [96, 0] [32, 64] [128, 1])
amdaie.end
}
}
return
}
}

// -----

// CHECK-LABEL: @combine_source_values
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([] [] [], [0, 0, 0, 0] [2, 16, 8, 16] [32, 32, 8, 1])
Expand Down Expand Up @@ -332,6 +354,28 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}

// -----

// Target access patterns with the same number of dimensions and identical
// strides, but different outermost sizes (32, 64, 32). Each op's outermost
// offset equals the previous op's outermost offset plus its outermost size
// (0, 32, 96), so the three ops are expected to be merged into a single op
// with outermost offset 0 and outermost size 32 + 64 + 32 = 128.
// CHECK-LABEL: @combine_target_same_dims_diff_sizes
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([0, 0] [128, 64] [128, 1], [] [] [])
// CHECK-NOT: amdaie.npu.dma_cpy_nd
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @combine_target_same_dims_diff_sizes(%arg0: !amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<128x128xi32>>) {
amdaie.workgroup {
%0 = amdaie.connection(%arg0, %arg1) : (!amdaie.logicalobjectfifo<memref<2048xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>)
amdaie.controlcode {
amdaie.npu.dma_cpy_nd %0([0, 0] [32, 64] [128, 1], [] [] [])
amdaie.npu.dma_cpy_nd %0([32, 0] [64, 64] [128, 1], [] [] [])
amdaie.npu.dma_cpy_nd %0([96, 0] [32, 64] [128, 1], [] [] [])
amdaie.end
}
}
return
}
}

// -----

// CHECK-LABEL: @combine_target_diff_dims
// CHECK: %[[CONNECTION:.+]] = amdaie.connection
// CHECK: amdaie.npu.dma_cpy_nd %[[CONNECTION]]([0, 0, 0, 32] [3, 16, 8, 16] [64, 32, 8, 1], [] [] [])
Expand Down

0 comments on commit c607072

Please sign in to comment.