@@ -585,9 +585,9 @@ class WGToSGVectorTranspose
585
585
// 3. Load the vector from slm using the result layout
586
586
587
587
// Example:
588
- // WG IR
589
- // #wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
590
- // #wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>
588
+ // WG IR
589
+ // #wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
590
+ // #wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>
591
591
// %vector_a = xetile.tile_conv_layout %vector_b {wg_map_result = #wg_map_a, wg_map_source = #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat>
592
592
593
593
// SG IR
@@ -613,7 +613,7 @@ class WGToSGXeTileConvertLayout
613
613
auto elemTy = resType.getElementType ();
614
614
auto resShape = resType.getShape ();
615
615
616
- auto dstMapAttr =
616
+ auto dstMapAttr =
617
617
llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr (" wg_map_result" ));
618
618
619
619
xetile::WorkGroupMapAttr srcMapAttr;
@@ -622,7 +622,7 @@ class WGToSGXeTileConvertLayout
622
622
if (!dstMapAttr) {
623
623
return mlir::failure ();
624
624
}
625
-
625
+
626
626
if (!srcMapAttr) {
627
627
// Get the map from operand
628
628
auto operand = op.getSource ().getDefiningOp ();
@@ -642,7 +642,7 @@ class WGToSGXeTileConvertLayout
642
642
return rewriter.create <mlir::arith::ConstantOp>(loc, type, attr);
643
643
};
644
644
645
- rewriter.setInsertionPoint (op);
645
+ rewriter.setInsertionPoint (op);
646
646
// Allocate SLM
647
647
// TODO: Allocate slm as 1D array of i8, and then create the expected view on it.
648
648
auto slmTy = mlir::MemRefType::get ({resShape[0 ], resShape[1 ]}, elemTy, {}, 3 );
@@ -675,25 +675,25 @@ class WGToSGXeTileConvertLayout
675
675
nullptr /* scatterAttr*/ );
676
676
xetile::TileType srcTileTy =
677
677
imex::xetile::TileType::get ({srcMapSgData[0 ], srcMapSgData[1 ]}, elemTy, attr);
678
-
678
+
679
679
auto storeOffsetX = rewriter.createOrFold <mlir::index::MulOp>(
680
680
loc, storeSgIdX, createIndexConstant (indexType, srcMapSgData[0 ]));
681
681
auto storeOffsetY = rewriter.createOrFold <mlir::index::MulOp>(
682
- loc, storeSgIdY, createIndexConstant (indexType, srcMapSgData[1 ]));
682
+ loc, storeSgIdY, createIndexConstant (indexType, srcMapSgData[1 ]));
683
683
auto storeInitTileOp = rewriter.create <xetile::InitTileOp>(
684
684
loc, srcTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
685
685
rewriter.create <xetile::StoreTileOp>(loc, adaptor.getSource ()[0 ],
686
686
storeInitTileOp);
687
687
688
688
// Add barrier
689
689
rewriter.create <mlir::gpu::BarrierOp>(loc);
690
-
690
+
691
691
// Load from SLM with result map
692
692
xetile::TileType dstTileTy =
693
693
imex::xetile::TileType::get ({dstMapSgData[0 ], dstMapSgData[1 ]}, elemTy, attr);
694
694
auto newResTy =
695
695
mlir::VectorType::get ({dstMapSgData[0 ], dstMapSgData[1 ]}, elemTy);
696
-
696
+
697
697
auto dstMapDimY = createIndexConstant (indexType, dstSgLayout[1 ]);
698
698
auto loadSgIdX = rewriter.create <mlir::index::FloorDivSOp>(loc, sgId, dstMapDimY);
699
699
auto loadSgIdY = rewriter.create <mlir::index::RemUOp>(loc, sgId, dstMapDimY);
@@ -705,7 +705,7 @@ class WGToSGXeTileConvertLayout
705
705
loc, dstTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
706
706
auto loadTile = rewriter.create <xetile::LoadTileOp>(
707
707
loc, newResTy, loadInitTileOp, mlir::Attribute ());
708
-
708
+
709
709
rewriter.replaceOp (op, loadTile);
710
710
return mlir::success ();
711
711
}
0 commit comments