Skip to content

Commit 964398e

Browse files
committed
Aling 'linalg-to-xegpu' pass with patched XeGPU dialect
Signed-off-by: dchigarev <[email protected]>
1 parent 1b9d6ac commit 964398e

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
597597
Location loc, ValueRange tiles,
598598
ArrayRef<int64_t> offsets) {
599599
SmallVector<Value> updatedTiles;
600+
// convert static offsets to dynamic because of this IMEX bug:
601+
// https://github.com/intel/mlir-extensions/issues/815
602+
std::vector<Value> dynOffsets;
603+
for (auto &x : offsets) {
604+
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
605+
dynOffsets.push_back(offset);
606+
}
607+
ValueRange newOffsets{dynOffsets};
600608
for (auto tile : tiles) {
601-
auto updatedTile =
602-
rewriter
603-
.create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
604-
/*offsets=*/ValueRange{}, offsets)
605-
.getResult();
609+
auto updatedTile = rewriter
610+
.create<xegpu::UpdateNdOffsetOp>(
611+
loc, tile.getType(), tile,
612+
/*offsets=*/newOffsets,
613+
SmallVector<int64_t>{ShapedType::kDynamic,
614+
ShapedType::kDynamic})
615+
.getResult();
606616
updatedTiles.push_back(updatedTile);
607617
}
608618

@@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
648658

649659
SmallVector<Value> tiles;
650660
for (int i = 0; i < loadShape[0]; i += descTile[0]) {
661+
// convert static offsets to dynamic because of this IMEX bug:
662+
// https://github.com/intel/mlir-extensions/issues/815
663+
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
651664
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
665+
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
652666
auto tile = rewriter
653667
.create<xegpu::UpdateNdOffsetOp>(
654668
loc, descType, rootTile,
655-
/*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
669+
/*offsets=*/ValueRange{newRowOffs, newColOffs},
670+
SmallVector<int64_t>{ShapedType::kDynamic,
671+
ShapedType::kDynamic})
656672
.getResult();
657673
tiles.push_back(tile);
658674
}
@@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
732748

733749
VectorType vecLoadType =
734750
VectorType::get(tileType.getShape(), tileType.getElementType());
735-
UnitAttr vnniAxisAttr = nullptr;
751+
mlir::UnitAttr packedAttr = nullptr;
736752
if (vnniConf) {
737-
vnniAxisAttr = UnitAttr::get(rewriter.getContext());
738753
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
739754
*vnniConf);
755+
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
740756
}
741-
757+
IntegerAttr transpose_bit = nullptr;
742758
SmallVector<Value> loadVec;
743759
for (auto tile : loadTiles) {
760+
744761
auto loadOp = rewriter.create<xegpu::LoadNdOp>(
745-
loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
762+
loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
746763
/*l1_hint=*/hint,
747764
/*l2_hint=*/hint, /*l3_hint=*/hint);
748765
loadVec.push_back(loadOp);
@@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
10571074

10581075
// Load A sub-tiles.
10591076
SmallVector<Value> loadVecA =
1060-
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
1077+
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
10611078
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());
10621079

10631080
// Load B sub-tiles.

0 commit comments

Comments
 (0)