diff --git a/lib/Transforms/OptimizeTranspose.cpp b/lib/Transforms/OptimizeTranspose.cpp index ac9a6d8a5..b91bc5b69 100644 --- a/lib/Transforms/OptimizeTranspose.cpp +++ b/lib/Transforms/OptimizeTranspose.cpp @@ -130,27 +130,26 @@ struct LoadTransposeAnalysis { return true; }; - // Helper to visit CreateNdDescOp and find all LoadNdOps that use it. - void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDescOp, - llvm::SmallVector &loadNdOpsFound) { + // Helper to visit CreateNdDescOp and UpdateNdOffsetOp + // and find all LoadNdOps that use it. + void visitCreateNdDescOrUpdateNdOffsetOp( + mlir::Operation *op, llvm::SmallVector &loadNdOpsFound) { llvm::SmallSet worklist; - worklist.insert(createNdDescOp); + worklist.insert(op); while (!worklist.empty()) { auto currOp = *worklist.begin(); worklist.erase(currOp); // We found a LoadNdOp. if (auto loadNdOp = llvm::dyn_cast_if_present(currOp)) { loadNdOpsFound.push_back(loadNdOp); - } - // Process all users of the current op. - else { + } else { // Process all users of the current op. for (auto user : currOp->getUsers()) { // If current user is a forOp, we need to get the block argument. if (auto forOp = llvm::dyn_cast_if_present(user)) { auto blockArg = imex::getArgForOperand(forOp, currOp->getResult(0)); for (auto user : blockArg.getUsers()) worklist.insert(user); - } else { + } else if (!llvm::isa(user)) { worklist.insert(user); } } @@ -242,10 +241,14 @@ struct LoadTransposeAnalysis { public: LoadTransposeAnalysis(Operation *op) { - op->walk([&](xegpu::CreateNdDescOp createNdDescOp) -> WalkResult { + op->walk([&](mlir::Operation *targetOp) -> WalkResult { + if (!llvm::isa(targetOp) && + !llvm::isa(targetOp)) + return WalkResult::skip(); + llvm::SmallVector loadNdOpsFound; // Find all LoadNdOps that use this CreateNdDescOp. - visitCreateNdDescOp(createNdDescOp, loadNdOpsFound); + visitCreateNdDescOrUpdateNdOffsetOp(targetOp, loadNdOpsFound); // If no LoadNdOps or more than one LoadNdOps are found, we skip. if (loadNdOpsFound.size() != 1) return WalkResult::skip(); @@ -283,7 +286,9 @@ struct LoadTransposeAnalysis { fusionCandidates.insert(loadNdOp); // Source CreateNdDescOp is considered for array length adjustment if // array_length > 1. - if (createNdDescOp.getTensorDesc().getType().getArrayLength() > 1) + auto createNdDescOp = llvm::dyn_cast(targetOp); + if (createNdDescOp && + createNdDescOp.getTensorDesc().getType().getArrayLength() > 1) arrayLenAdjustmentCandidates.insert(createNdDescOp); return WalkResult::advance(); }); diff --git a/test/Transforms/xegpu-optimize-transpose.mlir b/test/Transforms/xegpu-optimize-transpose.mlir index 2dbb381ee..5488b7f70 100644 --- a/test/Transforms/xegpu-optimize-transpose.mlir +++ b/test/Transforms/xegpu-optimize-transpose.mlir @@ -490,3 +490,33 @@ func.func @test_transpose(%arg0: memref<16x16xf16>, %arg1: memref<8x32xf16>) { //CHECK: %[[r17:.*]] = xegpu.create_nd_tdesc %[[arg1]][0, %[[r16]]] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> //CHECK: xegpu.store_nd %[[r15]], %[[r17]] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> } + +// ----- + //CHECK: func.func @test_load_update_nd_offset(%[[arg0:.*]]: memref<16x16xf16>, %[[arg1:.*]]: memref<16x32xf16>, %[[arg2:.*]]: memref<16x32xf32>) + func.func @test_load_update_nd_offset(%arg0: memref<16x16xf16>, %arg1: memref<16x32xf16>, %arg2: memref<16x32xf32>) { + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + %5 = vector.shape_cast %4 {packed} : vector<16x16xf16> to vector<256xf16> + %6 = vector.shuffle %5, %5 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16> + %7 = vector.shape_cast %6 {packed} : vector<256xf16> to vector<8x16x2xf16> + %8 = xegpu.dpas %2, %7 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %9 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %8, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %10 = xegpu.update_nd_offset %0, [8, 0] : !xegpu.tensor_desc<8x16xf16> + %11 = xegpu.update_nd_offset %1, [0, 16] : !xegpu.tensor_desc<16x16xf16> + %12 = xegpu.load_nd %10 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %13 = xegpu.load_nd %11 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %14 = vector.transpose %13, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + %15 = vector.shape_cast %14 {packed} : vector<16x16xf16> to vector<256xf16> + %16 = vector.shuffle %15, %15 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16> + %17 = vector.shape_cast %16 {packed} : vector<256xf16> to vector<8x16x2xf16> + %18 = xegpu.dpas %12, %17 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %19 = xegpu.update_nd_offset %9, [0, 16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %18, %19 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return + }