[LinalgToXeGPU] Lower linalg.matmul_transpose_b into xegpu.dpas (#347)

Signed-off-by: dchigarev <[email protected]>
dchigarev authored Sep 30, 2024
1 parent 199501e commit 1fee896
Showing 7 changed files with 557 additions and 19 deletions.
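
For orientation, here is a minimal MLIR sketch (not taken from this commit; the function name and the 32x1024 / 32x32 shapes are assumed purely for illustration) of the op form that the DPAS lowering now accepts directly:

func.func @matmul_transpose_b_example(%A: memref<32x1024xf16>, %B: memref<32x1024xf16>, %C: memref<32x32xf16>) {
  // C = A x transpose(B); B is supplied as (N x K) and the transpose is part of the op's semantics.
  linalg.matmul_transpose_b ins(%A, %B : memref<32x1024xf16>, memref<32x1024xf16>)
                            outs(%C : memref<32x32xf16>)
  return
}

Rather than materializing a transposed copy of B, the lowering loads the B tiles with swapped offsets and a transposed xegpu.load_nd (see the transpose and transpose_bit_width attributes checked in the new test). The same path is taken when a plain linalg.matmul is fed by a separate linalg.transpose that findAndReplaceTranspose can fold away.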
lib/gc/Transforms/GPU/LinalgToXeGPU.cpp: 142 changes (123 additions, 19 deletions)
@@ -94,6 +94,7 @@ static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile,
ArrayRef<int64_t> dpasTile) {
if (!(isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
isa<linalg::MatmulTransposeBOp>(linalgOp) ||
isa<linalg::GenericOp>(linalgOp))) {
return false;
}
@@ -633,12 +634,11 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
//
// The descriptor sub-tiles are ordered in row-major fashion with respect to the
// whole load tile.
static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
Location loc, Value src,
ArrayRef<int64_t> loadShape,
ArrayRef<int64_t> loadOffsets,
ArrayRef<int64_t> descTile,
int arrayLength = 1) {
static SmallVector<Value>
createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src,
ArrayRef<int64_t> loadShape,
ArrayRef<int64_t> loadOffsets, ArrayRef<int64_t> descTile,
int arrayLength = 1, bool transpose = false) {
assert(arrayLength == 1 && "Array descriptors are not supported");

auto type = cast<ShapedType>(src.getType());
@@ -669,6 +669,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
if (transpose) {
std::swap(newRowOffs, newColOffs);
}
auto tile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, descType, rootTile,
@@ -693,7 +696,8 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
Location loc, Value src,
ArrayRef<int64_t> sgTile,
bool isVnni) {
bool isVnni,
bool transpose = false) {
assert(sgTile.size() <= 2 &&
"Require at most 2D tile size for eltwise lowering");

@@ -727,7 +731,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
// NOLINTEND

return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0},
{sgLoadRows, sgLoadCols}, arrayLength);
{sgLoadRows, sgLoadCols}, arrayLength,
transpose);
}

// Return vector type with specified VNNI shape.
@@ -745,7 +750,8 @@ static SmallVector<Value>
loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
xegpu::CachePolicyAttr hint,
std::optional<VnniConfig> vnniConf = std::nullopt,
DenseI64ArrayAttr transpose = nullptr) {
DenseI64ArrayAttr transpose = nullptr,
IntegerAttr transpose_bit = nullptr) {
// Assume all tiles have the same shape.
auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
assert(llvm::all_of(loadTiles,
@@ -760,7 +766,6 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}
IntegerAttr transpose_bit = nullptr;
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {

@@ -860,13 +865,82 @@ extractVecSubTiles(PatternRewriter &rewriter, Location loc,
return subTiles;
}

// Checks whether the given `matmulOperand` is produced by a
// `linalg::TransposeOp` and ensures that the transpose result is only used by
// valid operations, such as `linalg::MatmulOp`, `linalg::BatchReduceMatmulOp`,
// or `linalg::GenericOp`.
//
// If a valid transpose operation is found, the function records it for later
// removal and returns the operand of the transpose operation as the new matrix
// multiplication operand.
static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
size_t operandIdx,
PatternRewriter &rewriter) {
auto defOp = matmulOperand.getDefiningOp();
if (!defOp) {
return failure();
}
linalg::TransposeOp transposeOp = nullptr;

for (auto x : defOp->getUsers()) {
if (isa<linalg::TransposeOp>(x)) {
if (transposeOp) {
return rewriter.notifyMatchFailure(
transposeOp, "Only one transpose operation is allowed");
}

transposeOp = dyn_cast<linalg::TransposeOp>(x);

auto transposeRes = transposeOp.getDpsInits()[0];
// verify that there are no users of the transpose result
// other than our matmul
for (auto trUser : transposeRes.getUsers()) {
if (isa<linalg::MatmulOp>(trUser) ||
isa<linalg::BatchReduceMatmulOp>(trUser) ||
isa<linalg::GenericOp>(trUser)) {
auto matmulOp = dyn_cast<linalg::LinalgOp>(trUser);
auto actualMatmulOperand = matmulOp.getDpsInputs()[operandIdx];
if (actualMatmulOperand != matmulOperand) {
return rewriter.notifyMatchFailure(
trUser,
"Transpose result is used by more than one matmul operation");
}
} else if (isa<memref::DeallocOp>(trUser)) {
// allow deallocs as users
continue;
} else if (isa<linalg::TransposeOp>(trUser)) {
// check if it's the same transpose as we're processing
if (!mlir::OperationEquivalence::isEquivalentTo(trUser, transposeOp,
/*flags=*/nullptr)) {
return rewriter.notifyMatchFailure(
trUser, "Only one transpose operation is allowed");
}
continue;
} else {
return rewriter.notifyMatchFailure(
trUser,
"Transpose result is not allowed to be used by this operation");
}
}
}
}
if (transposeOp) {
auto ret = transposeOp.getDpsInputs()[0];
rewriter.eraseOp(transposeOp);
return ret;
}
return rewriter.notifyMatchFailure(
defOp, "No transpose operation producing the operand was found");
}

// Create XeGPU DPAS kernel out of GEMM-like operation.
static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
ArrayRef<int64_t> dpasTile, int kTile,
int prefetchStages,
PatternRewriter &rewriter) {
assert((isa<linalg::MatmulOp>(linalgOp) ||
isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
isa<linalg::MatmulTransposeBOp>(linalgOp) ||
isa<linalg::GenericOp>(linalgOp)) &&
"Requires a GEMM-like op for DPAS lowering");

@@ -877,6 +951,17 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
auto matB = linalgOp.getDpsInputs()[1];
auto matC = linalgOp.getDpsInits()[0];

bool transposeB = false;
if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
transposeB = true;
} else {
auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
if (!failed(newMatB)) {
matB = *newMatB;
transposeB = true;
}
}

auto typeA = cast<ShapedType>(matA.getType());
auto typeC = cast<ShapedType>(matC.getType());

@@ -961,7 +1046,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// Create B sub-tiles.
SmallVector<Value> tilesB =
createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true);
createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN},
/*isVnni=*/true, transposeB);

// Create input prefetch tiles.
int64_t numThreads = 1;
@@ -997,7 +1083,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
{dimM, dimN}, kTile);
auto prefetchDescB = createGemmCoopPrefetchTile(
rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols},
{dimM, dimN}, kTile);
(transposeB) ? std::vector<int32_t>{dimM, dimN}
: std::vector<int32_t>{dimN, dimM},
kTile);

if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) {
prefetchA = prefetchDescA->getResult();
@@ -1012,7 +1100,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
{0, kTile})[0];
prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
{kTile, 0})[0];
(transposeB)
? std::vector<int64_t>{0, kTile}
: std::vector<int64_t>{kTile, 0})[0];
}
} else {
// Disable coop prefetching on failure.
@@ -1083,15 +1173,26 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

DenseI64ArrayAttr transpose = nullptr;
IntegerAttr transpose_bit = nullptr;

if (transposeB) {
transpose_bit = rewriter.getIntegerAttr(rewriter.getIntegerType(32), 32);
transpose = DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
}

// Load B sub-tiles.
SmallVector<Value> loadVecB =
loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB);
loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB,
transpose, transpose_bit);
auto tileTypeB = cast<xegpu::TensorDescType>(tilesB[0].getType());

// Update offsets of the input tiles.
// Shift along the reduction dimension.
tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile});
tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0});
tilesB = updateTilesOffsets(rewriter, loc, tilesB,
transposeB ? std::vector<int64_t>{0, kTile}
: std::vector<int64_t>{kTile, 0});

// Prefetch the next set of input tiles.
if (isCoopPrefetch) {
@@ -1101,7 +1202,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
prefetchA =
updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, {0, kTile})[0];
prefetchB =
updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0];
updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
transposeB ? std::vector<int64_t>{0, kTile}
: std::vector<int64_t>{kTile, 0})[0];
} else {
// Apply naive prefetching for each subgroup separately.
prefetchTiles(rewriter, loc, tilesA, readCacheHint);
@@ -1288,7 +1391,7 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
// Constrain conversion to the supported GEMM-like ops.
static_assert(
llvm::is_one_of<LinalgOpTy, linalg::MatmulOp, linalg::BatchReduceMatmulOp,
linalg::GenericOp>::value);
linalg::GenericOp, linalg::MatmulTransposeBOp>::value);

ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
: OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
@@ -1495,8 +1598,9 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
LinalgToXeGPUOptions options) {
patterns.add<ConvertGemmLikeToXeGPU<linalg::MatmulOp>,
ConvertGemmLikeToXeGPU<linalg::GenericOp>>(patterns.getContext(),
options);
ConvertGemmLikeToXeGPU<linalg::GenericOp>,
ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>>(
patterns.getContext(), options);
}

void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
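For reference, a minimal C++ sketch of how the pattern set extended above might be driven; this is not part of the commit: runGemmLowering is an assumed name, project-specific headers and namespaces are omitted, and only populateLinalgGemmToXeGPUPatterns and LinalgToXeGPUOptions come from the code above.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runGemmLowering(Operation *root,
                                     LinalgToXeGPUOptions options) {
  RewritePatternSet patterns(root->getContext());
  // After this commit, the call below also registers
  // ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>.
  populateLinalgGemmToXeGPUPatterns(patterns, options);
  // Greedily apply the GEMM lowering patterns to everything under `root`.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}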
@@ -0,0 +1,86 @@
// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s

module {
func.func @matmul_transpose_b_sep(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xf16>
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
%subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
%subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
%subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
%subview_3 = memref.subview %alloc[0, %arg4] [1024, 32] [1, 1] : memref<1024x1024xf16> to memref<1024x32xf16, strided<[1024, 1], offset: ?>>
linalg.transpose ins(%subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_3 : memref<1024x32xf16, strided<[1024, 1], offset: ?>>) permutation = [1, 0]
linalg.matmul ins(%subview_1, %subview_3 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x32xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
scf.reduce
}
memref.dealloc %alloc : memref<1024x1024xf16>
return
}
}

// CHECK-LABEL: func.func @matmul_transpose_b_sep
// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>

// CHECK-NOT: memref.alloc()

// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}

// CHECK-NOT: linalg.transpose

// Create output initial value load tiles.
// CHECK-DAG: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
// CHECK-COUNT-7: xegpu.update_nd_offset

// Load initial accumulator values.
// CHECK-DAG: %[[vC:.+]] = xegpu.load_nd %[[tC]]
// CHECK-COUNT-7: xegpu.load_nd

// Extend the type to match DPAS output precision.
// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]]
// CHECK-COUNT-7: arith.extf

// Create input load tiles.
// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]

// Create DPAS computation loop over tiled reduction dimension.
// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
// CHECK-SAME: {

// Load input values and update the load tile position.
// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]]
// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] {{.*}}transpose = array<i64: 1, 0>{{.*}}transpose_bit_width = 32 : i32{{.*}}
// CHECK: %[[vB1:.+]] = xegpu.load_nd %[[iterB1]] {{.*}}transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32{{.*}}

// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]], [%c0, %c16]
// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]], [%c0, %c16]
// CHECK: %[[new_tB1:.+]] = xegpu.update_nd_offset %[[iterB1]], [%c0, %c16]

// Apply simple prefetching scheme - start loading the next set of input
// tiles before computation is started.
// CHECK: xegpu.prefetch_nd %[[new_tA]]
// CHECK: xegpu.prefetch_nd %[[new_tB]]
// CHECK: xegpu.prefetch_nd %[[new_tB1]]

// Extract DPAS-sized chunks from larger loaded tile A.
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK-COUNT-3: vector.extract_strided_slice

// Perform DPAS computation.
// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
// CHECK-COUNT-7: xegpu.dpas

// CHECK-NOT: memref.dealloc()