@@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
597
597
Location loc, ValueRange tiles,
598
598
ArrayRef<int64_t > offsets) {
599
599
SmallVector<Value> updatedTiles;
600
+ // convert static offsets to dynamic because of this IMEX bug:
601
+ // https://github.com/intel/mlir-extensions/issues/815
602
+ std::vector<Value> dynOffsets;
603
+ for (auto &x : offsets) {
604
+ Value offset = rewriter.create <arith::ConstantIndexOp>(loc, x);
605
+ dynOffsets.push_back (offset);
606
+ }
607
+ ValueRange newOffsets{dynOffsets};
600
608
for (auto tile : tiles) {
601
- auto updatedTile =
602
- rewriter
603
- .create <xegpu::UpdateNdOffsetOp>(loc, tile.getType (), tile,
604
- /* offsets=*/ ValueRange{}, offsets)
605
- .getResult ();
609
+ auto updatedTile = rewriter
610
+ .create <xegpu::UpdateNdOffsetOp>(
611
+ loc, tile.getType (), tile,
612
+ /* offsets=*/ newOffsets,
613
+ SmallVector<int64_t >{ShapedType::kDynamic ,
614
+ ShapedType::kDynamic })
615
+ .getResult ();
606
616
updatedTiles.push_back (updatedTile);
607
617
}
608
618
@@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
648
658
649
659
SmallVector<Value> tiles;
650
660
for (int i = 0 ; i < loadShape[0 ]; i += descTile[0 ]) {
661
+ // convert static offsets to dynamic because of this IMEX bug:
662
+ // https://github.com/intel/mlir-extensions/issues/815
663
+ Value newRowOffs = rewriter.create <arith::ConstantIndexOp>(loc, i);
651
664
for (int j = 0 ; j < loadShape[1 ]; j += descTile[1 ] * arrayLength) {
665
+ Value newColOffs = rewriter.create <arith::ConstantIndexOp>(loc, j);
652
666
auto tile = rewriter
653
667
.create <xegpu::UpdateNdOffsetOp>(
654
668
loc, descType, rootTile,
655
- /* offsets=*/ ValueRange{}, SmallVector<int64_t >{i, j})
669
+ /* offsets=*/ ValueRange{newRowOffs, newColOffs},
670
+ SmallVector<int64_t >{ShapedType::kDynamic ,
671
+ ShapedType::kDynamic })
656
672
.getResult ();
657
673
tiles.push_back (tile);
658
674
}
@@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
732
748
733
749
VectorType vecLoadType =
734
750
VectorType::get (tileType.getShape (), tileType.getElementType ());
735
- UnitAttr vnniAxisAttr = nullptr ;
751
+ mlir:: UnitAttr packedAttr = nullptr ;
736
752
if (vnniConf) {
737
- vnniAxisAttr = UnitAttr::get (rewriter.getContext ());
738
753
vecLoadType = getVnniVector (tileType.getShape (), tileType.getElementType (),
739
754
*vnniConf);
755
+ packedAttr = mlir::UnitAttr::get (rewriter.getContext ());
740
756
}
741
-
757
+ IntegerAttr transpose_bit = nullptr ;
742
758
SmallVector<Value> loadVec;
743
759
for (auto tile : loadTiles) {
760
+
744
761
auto loadOp = rewriter.create <xegpu::LoadNdOp>(
745
- loc, vecLoadType, tile, vnniAxisAttr , transpose, nullptr ,
762
+ loc, vecLoadType, tile, packedAttr , transpose, transpose_bit ,
746
763
/* l1_hint=*/ hint,
747
764
/* l2_hint=*/ hint, /* l3_hint=*/ hint);
748
765
loadVec.push_back (loadOp);
@@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
1057
1074
1058
1075
// Load A sub-tiles.
1059
1076
SmallVector<Value> loadVecA =
1060
- loadNdDescTiles (rewriter, loc, tilesA, readCacheHint, vnniConfA );
1077
+ loadNdDescTiles (rewriter, loc, tilesA, readCacheHint);
1061
1078
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0 ].getType ());
1062
1079
1063
1080
// Load B sub-tiles.
0 commit comments