Fix a bug in visitCreateNdDescOp, which mistakenly treated a CreateNdDesc as used twice because the users of its UpdateNdOffset were also checked in the chain.
chencha3 committed Dec 14, 2024
1 parent bee50a0 commit b094563
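In effect, the old walk from a CreateNdDescOp also followed the users of any xegpu.update_nd_offset derived from it, so a descriptor that is loaded once and then advanced for a second load appeared to feed two LoadNdOps, and the analysis (which requires exactly one LoadNdOp per descriptor) skipped the transpose optimization. A minimal sketch of IR that hits this case, using hypothetical names and shapes rather than the exact test added below:

// Sketch only: @sketch, %src, and the shapes are illustrative.
func.func @sketch(%src: memref<16x32xf16>) -> (vector<16x16xf16>, vector<16x16xf16>) {
  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16>
  // First load of the descriptor, later transposed.
  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
  %2 = vector.transpose %1, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
  // Advancing the same descriptor and loading again previously made the walk
  // from %0 collect both load_nd ops and bail out.
  %3 = xegpu.update_nd_offset %0, [0, 16] : !xegpu.tensor_desc<16x16xf16>
  %4 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
  %5 = vector.transpose %4, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
  return %2, %5 : vector<16x16xf16>, vector<16x16xf16>
}

With the fix, CreateNdDescOp and UpdateNdOffsetOp are walked as separate roots and the chain no longer descends into UpdateNdOffsetOp users, so each descriptor value is matched against exactly one LoadNdOp.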
Showing 2 changed files with 46 additions and 11 deletions.
27 changes: 16 additions & 11 deletions lib/Transforms/OptimizeTranspose.cpp
@@ -130,27 +130,26 @@ struct LoadTransposeAnalysis {
     return true;
   };
 
-  // Helper to visit CreateNdDescOp and find all LoadNdOps that use it.
-  void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDescOp,
-                           llvm::SmallVector<Operation *> &loadNdOpsFound) {
+  // Helper to visit CreateNdDescOp and UpdateNdOffsetOp
+  // and find all LoadNdOps that use it.
+  void visitCreateNdDescOrUpdateNdOffsetOp(
+      mlir::Operation *op, llvm::SmallVector<Operation *> &loadNdOpsFound) {
     llvm::SmallSet<Operation *, 8> worklist;
-    worklist.insert(createNdDescOp);
+    worklist.insert(op);
     while (!worklist.empty()) {
       auto currOp = *worklist.begin();
       worklist.erase(currOp);
       // We found a LoadNdOp.
       if (auto loadNdOp = llvm::dyn_cast_if_present<xegpu::LoadNdOp>(currOp)) {
         loadNdOpsFound.push_back(loadNdOp);
-      }
-      // Process all users of the current op.
-      else {
+      } else { // Process all users of the current op.
         for (auto user : currOp->getUsers()) {
           // If current user is a forOp, we need to get the block argument.
           if (auto forOp = llvm::dyn_cast_if_present<scf::ForOp>(user)) {
             auto blockArg = imex::getArgForOperand(forOp, currOp->getResult(0));
             for (auto user : blockArg.getUsers())
               worklist.insert(user);
-          } else {
+          } else if (!llvm::isa<xegpu::UpdateNdOffsetOp>(user)) {
             worklist.insert(user);
           }
         }
@@ -242,10 +241,14 @@ struct LoadTransposeAnalysis {
 
 public:
   LoadTransposeAnalysis(Operation *op) {
-    op->walk([&](xegpu::CreateNdDescOp createNdDescOp) -> WalkResult {
+    op->walk([&](mlir::Operation *targetOp) -> WalkResult {
+      if (!llvm::isa<xegpu::CreateNdDescOp>(targetOp) &&
+          !llvm::isa<xegpu::UpdateNdOffsetOp>(targetOp))
+        return WalkResult::skip();
+
       llvm::SmallVector<Operation *> loadNdOpsFound;
       // Find all LoadNdOps that use this CreateNdDescOp.
-      visitCreateNdDescOp(createNdDescOp, loadNdOpsFound);
+      visitCreateNdDescOrUpdateNdOffsetOp(targetOp, loadNdOpsFound);
       // If no LoadNdOps or more than one LoadNdOps are found, we skip.
       if (loadNdOpsFound.size() != 1)
         return WalkResult::skip();
@@ -283,7 +286,9 @@ struct LoadTransposeAnalysis {
       fusionCandidates.insert(loadNdOp);
       // Source CreateNdDescOp is considered for array length adjustment if
       // array_length > 1.
-      if (createNdDescOp.getTensorDesc().getType().getArrayLength() > 1)
+      auto createNdDescOp = llvm::dyn_cast<xegpu::CreateNdDescOp>(targetOp);
+      if (createNdDescOp &&
+          createNdDescOp.getTensorDesc().getType().getArrayLength() > 1)
         arrayLenAdjustmentCandidates.insert(createNdDescOp);
       return WalkResult::advance();
     });
30 changes: 30 additions & 0 deletions test/Transforms/xegpu-optimize-transpose.mlir
@@ -490,3 +490,33 @@ func.func @test_transpose(%arg0: memref<16x16xf16>, %arg1: memref<8x32xf16>) {
//CHECK: %[[r17:.*]] = xegpu.create_nd_tdesc %[[arg1]][0, %[[r16]]] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16>
//CHECK: xegpu.store_nd %[[r15]], %[[r17]] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
}

// -----
//CHECK: func.func @test_load_update_nd_offset(%[[arg0:.*]]: memref<16x16xf16>, %[[arg1:.*]]: memref<16x32xf16>, %[[arg2:.*]]: memref<16x32xf32>)
func.func @test_load_update_nd_offset(%arg0: memref<16x16xf16>, %arg1: memref<16x32xf16>, %arg2: memref<16x32xf32>) {
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16>
%2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
//CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
%3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%5 = vector.shape_cast %4 {packed} : vector<16x16xf16> to vector<256xf16>
%6 = vector.shuffle %5, %5 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16>
%7 = vector.shape_cast %6 {packed} : vector<256xf16> to vector<8x16x2xf16>
%8 = xegpu.dpas %2, %7 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%9 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %8, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
%10 = xegpu.update_nd_offset %0, [8, 0] : !xegpu.tensor_desc<8x16xf16>
%11 = xegpu.update_nd_offset %1, [0, 16] : !xegpu.tensor_desc<16x16xf16>
%12 = xegpu.load_nd %10 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
//CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
%13 = xegpu.load_nd %11 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%14 = vector.transpose %13, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%15 = vector.shape_cast %14 {packed} : vector<16x16xf16> to vector<256xf16>
%16 = vector.shuffle %15, %15 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16>
%17 = vector.shape_cast %16 {packed} : vector<256xf16> to vector<8x16x2xf16>
%18 = xegpu.dpas %12, %17 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%19 = xegpu.update_nd_offset %9, [0, 16] : !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %18, %19 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
