[LinalgToXeGPU] Lower linalg.matmul_transpose_b into xegpu.dpas #347

Merged: 11 commits, Sep 30, 2024
142 changes: 123 additions & 19 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -94,6 +94,7 @@ static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile,
ArrayRef<int64_t> dpasTile) {
if (!(isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
isa<linalg::MatmulTransposeBOp>(linalgOp) ||
isa<linalg::GenericOp>(linalgOp))) {
return false;
}
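For reference, a minimal sketch of an op that now passes this check. The function name and the 32x1024/32x32 f16 shapes are illustrative assumptions; the dpas-tile and k-tile divisibility checks performed later still apply.

// C(32x32) = A(32x1024) * B(32x1024)^T, all operands in f16.
func.func @matmul_transpose_b_example(%A: memref<32x1024xf16>,
                                      %B: memref<32x1024xf16>,
                                      %C: memref<32x32xf16>) {
  linalg.matmul_transpose_b
    ins(%A, %B : memref<32x1024xf16>, memref<32x1024xf16>)
    outs(%C : memref<32x32xf16>)
  return
}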
@@ -633,12 +634,11 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
//
// The descriptor sub-tiles are ordered in row-major fashion with respect to the
// whole load tile.
static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
Location loc, Value src,
ArrayRef<int64_t> loadShape,
ArrayRef<int64_t> loadOffsets,
ArrayRef<int64_t> descTile,
int arrayLength = 1) {
static SmallVector<Value>
createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src,
ArrayRef<int64_t> loadShape,
ArrayRef<int64_t> loadOffsets, ArrayRef<int64_t> descTile,
int arrayLength = 1, bool transpose = false) {
assert(arrayLength == 1 && "Array descriptors are not supported");

auto type = cast<ShapedType>(src.getType());
@@ -669,6 +669,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
if (transpose) {
// Swap the offsets: this changes the iteration dimension for B chunks.
std::swap(newRowOffs, newColOffs);
}
auto tile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, descType, rootTile,
@@ -693,7 +696,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
Location loc, Value src,
ArrayRef<int64_t> sgTile,
bool isVnni) {
bool isVnni,
bool transpose = false) {
assert(sgTile.size() <= 2 &&
"Require at most 2D tile size for eltwise lowering");

@@ -727,7 +731,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
// NOLINTEND

return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0},
{sgLoadRows, sgLoadCols}, arrayLength);
{sgLoadRows, sgLoadCols}, arrayLength,
transpose);
}
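When `transpose` is set, the sub-tile descriptors created above advance along B's row dimension rather than its columns. A minimal sketch of the expected descriptor chain, modeled on the CHECK lines in the new test; the 16x16 tile shape and the plain memref type are assumptions for illustration only.

%c0  = arith.constant 0 : index
%c16 = arith.constant 16 : index
// Root descriptor over the (untransposed) B buffer.
%rootB = xegpu.create_nd_tdesc %B[0, 0]
    : memref<32x1024xf16> -> !xegpu.tensor_desc<16x16xf16>
// With transpose, consecutive chunks step along rows ([%c16, %c0])
// instead of along columns ([%c0, %c16]).
%tB  = xegpu.update_nd_offset %rootB, [%c0, %c0]  : !xegpu.tensor_desc<16x16xf16>
%tB1 = xegpu.update_nd_offset %rootB, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>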

// Return vector type with specified VNNI shape.
@@ -745,7 +750,8 @@ static SmallVector<Value>
loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
xegpu::CachePolicyAttr hint,
std::optional<VnniConfig> vnniConf = std::nullopt,
DenseI64ArrayAttr transpose = nullptr) {
DenseI64ArrayAttr transpose = nullptr,
IntegerAttr transpose_bit = nullptr) {
// Assume all tiles have the same shape.
auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
assert(llvm::all_of(loadTiles,
@@ -760,7 +766,6 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}
IntegerAttr transpose_bit = nullptr;
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {

@@ -860,13 +865,82 @@ extractVecSubTiles(PatternRewriter &rewriter, Location loc,
return subTiles;
}

// Checks whether the given `matmulOperand` is produced by a
// `linalg::TransposeOp` and ensures that the transpose result is only used by
// valid operations, such as `linalg::MatmulOp`, `linalg::BatchReduceMatmulOp`,
// or `linalg::GenericOp`.
//
// If a valid transpose operation is found, the function records it for later
// removal and returns the operand of the transpose operation as the new matrix
// multiplication operand.
static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
size_t operandIdx,
PatternRewriter &rewriter) {
auto defOp = matmulOperand.getDefiningOp();
if (!defOp) {
return failure();
}
linalg::TransposeOp transposeOp = nullptr;

for (auto x : defOp->getUsers()) {
if (isa<linalg::TransposeOp>(x)) {
if (transposeOp) {
return rewriter.notifyMatchFailure(
transposeOp, "Only one transpose operation is allowed");
}

transposeOp = dyn_cast<linalg::TransposeOp>(x);

auto transposeRes = transposeOp.getDpsInits()[0];
// Verify that there are no users of the transpose result
// other than our matmul.
for (auto trUser : transposeRes.getUsers()) {
if (isa<linalg::MatmulOp>(trUser) ||
isa<linalg::BatchReduceMatmulOp>(trUser) ||
isa<linalg::GenericOp>(trUser)) {
auto matmulOp = dyn_cast<linalg::LinalgOp>(trUser);
auto actualMatmulOperand = matmulOp.getDpsInputs()[operandIdx];
if (actualMatmulOperand != matmulOperand) {
return rewriter.notifyMatchFailure(
trUser,
"Transpose result is used by more than one matmul operation");
}
} else if (isa<memref::DeallocOp>(trUser)) {
// allow deallocs as users
continue;
} else if (isa<linalg::TransposeOp>(trUser)) {
// Check whether this is the same transpose op we are processing.
if (!mlir::OperationEquivalence::isEquivalentTo(trUser, transposeOp,
/*flags=*/nullptr)) {
return rewriter.notifyMatchFailure(
trUser, "Only one transpose operation is allowed");
}
continue;
} else {
return rewriter.notifyMatchFailure(
trUser,
"Transpose result is not allowed to be used by this operation");
}
}
}
}
if (transposeOp) {
auto ret = transposeOp.getDpsInputs()[0];
rewriter.eraseOp(transposeOp);
return ret;
}
return rewriter.notifyMatchFailure(
defOp, "No transpose operation producing the operand was found");
}
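For reference, a trimmed sketch (modeled on the new test case below) of the IR this helper matches: the transpose result is consumed only by the matmul and a dealloc of the temporary buffer. After the rewrite, matB becomes %B and the transpose op is erased. Function name and shapes here are illustrative assumptions.

func.func @separate_transpose(%A: memref<32x1024xf16>, %B: memref<32x1024xf16>,
                              %C: memref<32x32xf16>) {
  %buf = memref.alloc() : memref<1024x32xf16>
  linalg.transpose ins(%B : memref<32x1024xf16>)
                   outs(%buf : memref<1024x32xf16>) permutation = [1, 0]
  linalg.matmul ins(%A, %buf : memref<32x1024xf16>, memref<1024x32xf16>)
                outs(%C : memref<32x32xf16>)
  memref.dealloc %buf : memref<1024x32xf16>
  return
}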

// Create XeGPU DPAS kernel out of GEMM-like operation.
static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
ArrayRef<int64_t> dpasTile, int kTile,
int prefetchStages,
PatternRewriter &rewriter) {
assert((isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
isa<linalg::MatmulTransposeBOp>(linalgOp) ||
isa<linalg::GenericOp>(linalgOp)) &&
"Requires a GEMM-like op for DPAS lowering");

@@ -877,6 +951,17 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
auto matB = linalgOp.getDpsInputs()[1];
auto matC = linalgOp.getDpsInits()[0];

bool transposeB = false;
if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
transposeB = true;
} else {
auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
if (!failed(newMatB)) {
matB = *newMatB;
transposeB = true;
}
}

auto typeA = cast<ShapedType>(matA.getType());
auto typeC = cast<ShapedType>(matC.getType());

@@ -961,7 +1046,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// Create B sub-tiles.
SmallVector<Value> tilesB =
createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true);
createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN},
/*isVnni=*/true, transposeB);

// Create input prefetch tiles.
int64_t numThreads = 1;
@@ -997,7 +1083,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
{dimM, dimN}, kTile);
auto prefetchDescB = createGemmCoopPrefetchTile(
rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols},
{dimM, dimN}, kTile);
(transposeB) ? std::vector<int32_t>{dimM, dimN}
: std::vector<int32_t>{dimN, dimM},
kTile);

if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) {
prefetchA = prefetchDescA->getResult();
@@ -1012,7 +1100,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
{0, kTile})[0];
prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
{kTile, 0})[0];
(transposeB)
? std::vector<int64_t>{0, kTile}
: std::vector<int64_t>{kTile, 0})[0];
}
} else {
// Disable coop prefetching on failure.
@@ -1083,15 +1173,26 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

DenseI64ArrayAttr transpose = nullptr;
IntegerAttr transpose_bit = nullptr;

if (transposeB) {
transpose_bit = rewriter.getIntegerAttr(rewriter.getIntegerType(32), 32);
transpose = DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
}

// Load B sub-tiles.
SmallVector<Value> loadVecB =
loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB);
loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB,
transpose, transpose_bit);
auto tileTypeB = cast<xegpu::TensorDescType>(tilesB[0].getType());

// Update offsets of the input tiles.
// Shift along the reduction dimension.
tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile});
tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0});
tilesB = updateTilesOffsets(rewriter, loc, tilesB,
transposeB ? std::vector<int64_t>{0, kTile}
: std::vector<int64_t>{kTile, 0});

// Prefetch the next set of input tiles.
if (isCoopPrefetch) {
Expand All @@ -1101,7 +1202,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
prefetchA =
updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, {0, kTile})[0];
prefetchB =
updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0];
updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
transposeB ? std::vector<int64_t>{0, kTile}
: std::vector<int64_t>{kTile, 0})[0];
} else {
// Apply naive prefetching for each subgroup separately.
prefetchTiles(rewriter, loc, tilesA, readCacheHint);
@@ -1288,7 +1391,7 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
// Constrain conversion to the supported GEMM-like ops.
static_assert(
llvm::is_one_of<LinalgOpTy, linalg::MatmulOp, linalg::BatchReduceMatmulOp,
linalg::GenericOp>::value);
linalg::GenericOp, linalg::MatmulTransposeBOp>::value);

ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
: OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
@@ -1495,8 +1598,9 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
LinalgToXeGPUOptions options) {
patterns.add<ConvertGemmLikeToXeGPU<linalg::MatmulOp>,
ConvertGemmLikeToXeGPU<linalg::GenericOp>>(patterns.getContext(),
options);
ConvertGemmLikeToXeGPU<linalg::GenericOp>,
ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>>(
patterns.getContext(), options);
}

void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
86 changes: 86 additions & 0 deletions (new test file)
@@ -0,0 +1,86 @@
// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s

module {
func.func @matmul_transpose_b_sep(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xf16>
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
%subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
%subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
%subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
%subview_3 = memref.subview %alloc[0, %arg4] [1024, 32] [1, 1] : memref<1024x1024xf16> to memref<1024x32xf16, strided<[1024, 1], offset: ?>>
linalg.transpose ins(%subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_3 : memref<1024x32xf16, strided<[1024, 1], offset: ?>>) permutation = [1, 0]
linalg.matmul ins(%subview_1, %subview_3 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x32xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
scf.reduce
}
memref.dealloc %alloc : memref<1024x1024xf16>
return
}
}

// CHECK-LABEL: func.func @matmul_transpose_b_sep
// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>

// CHECK-NOT: memref.alloc()

// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}

// CHECK-NOT: linalg.transpose

// Create output initial value load tiles.
// CHECK-DAG: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
// CHECK-COUNT-7: xegpu.update_nd_offset

// Load initial accumulator values.
// CHECK-DAG: %[[vC:.+]] = xegpu.load_nd %[[tC]]
// CHECK-COUNT-7: xegpu.load_nd

// Extend the type to match DPAS output precision.
// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]]
// CHECK-COUNT-7: arith.extf

// Create input load tiles.
// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]

// Create DPAS computation loop over tiled reduction dimension.
// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
// CHECK-SAME: {

// Load input values and update the load tile position.
// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]]
// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] {{.*}}transpose = array<i64: 1, 0>{{.*}}transpose_bit_width = 32 : i32{{.*}}
// CHECK: %[[vB1:.+]] = xegpu.load_nd %[[iterB1]] {{.*}}transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32{{.*}}

// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]], [%c0, %c16]
// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]], [%c0, %c16]
// CHECK: %[[new_tB1:.+]] = xegpu.update_nd_offset %[[iterB1]], [%c0, %c16]

// Apply simple prefetching scheme - start loading the next set of input
// tiles before computation is started.
// CHECK: xegpu.prefetch_nd %[[new_tA]]
// CHECK: xegpu.prefetch_nd %[[new_tB]]
// CHECK: xegpu.prefetch_nd %[[new_tB1]]

// Extract DPAS-sized chunks from larger loaded tile A.
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK-COUNT-3: vector.extract_strided_slice

// Perform DPAS computation.
// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
// CHECK-COUNT-7: xegpu.dpas

// CHECK-NOT: memref.dealloc()