Skip to content

Commit

Permalink
Fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel committed Nov 6, 2024
1 parent 6d4aa6a commit a19f60f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
22 changes: 11 additions & 11 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,9 @@ class WGToSGVectorTranspose
// 3. Load the vector from slm using the result layout

// Example:
// WG IR
// #wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
// #wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>
// WG IR
// #wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
// #wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>
// %vector_a = xetile.tile_conv_layout %vector_b {wg_map_result = #wg_map_a, wg_map_source = #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat>

// SG IR
Expand All @@ -613,7 +613,7 @@ class WGToSGXeTileConvertLayout
auto elemTy = resType.getElementType();
auto resShape = resType.getShape();

auto dstMapAttr =
auto dstMapAttr =
llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("wg_map_result"));

xetile::WorkGroupMapAttr srcMapAttr;
Expand All @@ -622,7 +622,7 @@ class WGToSGXeTileConvertLayout
if (!dstMapAttr) {
return mlir::failure();
}

if(!srcMapAttr) {
// Get the map from operand
auto operand = op.getSource().getDefiningOp();
Expand All @@ -642,7 +642,7 @@ class WGToSGXeTileConvertLayout
return rewriter.create<mlir::arith::ConstantOp>(loc, type, attr);
};

rewriter.setInsertionPoint(op);
rewriter.setInsertionPoint(op);
// Allocate SLM
// TODO: Allocate slm as 1D array of i8, and then create the expected view on it.
auto slmTy = mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3);
Expand Down Expand Up @@ -675,25 +675,25 @@ class WGToSGXeTileConvertLayout
nullptr /*scatterAttr*/);
xetile::TileType srcTileTy =
imex::xetile::TileType::get({srcMapSgData[0], srcMapSgData[1]}, elemTy, attr);

auto storeOffsetX = rewriter.createOrFold<mlir::index::MulOp>(
loc, storeSgIdX, createIndexConstant(indexType, srcMapSgData[0]));
auto storeOffsetY = rewriter.createOrFold<mlir::index::MulOp>(
loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1]));
loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1]));
auto storeInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, srcTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
rewriter.create<xetile::StoreTileOp>(loc, adaptor.getSource()[0],
storeInitTileOp);

// Add barrier
rewriter.create<mlir::gpu::BarrierOp>(loc);

// Load from SLM with result map
xetile::TileType dstTileTy =
imex::xetile::TileType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy, attr);
auto newResTy =
mlir::VectorType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy);

auto dstMapDimY = createIndexConstant(indexType, dstSgLayout[1]);
auto loadSgIdX = rewriter.create<mlir::index::FloorDivSOp>(loc, sgId, dstMapDimY);
auto loadSgIdY = rewriter.create<mlir::index::RemUOp>(loc, sgId, dstMapDimY);
Expand All @@ -705,7 +705,7 @@ class WGToSGXeTileConvertLayout
loc, dstTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
auto loadTile = rewriter.create<xetile::LoadTileOp>(
loc, newResTy, loadInitTileOp, mlir::Attribute());

rewriter.replaceOp(op, loadTile);
return mlir::success();
}
Expand Down
4 changes: 1 addition & 3 deletions test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,5 @@ gpu.module @test_convert_layout{
%convert_layout = xetile.convert_layout %cst {wg_map_result = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>, wg_map_source = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} : vector<256x256xf32>
%add = arith.addf %cst_temp, %convert_layout {map = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>} : vector<256x256xf32>
gpu.return
}
}
}


0 comments on commit a19f60f

Please sign in to comment.