Skip to content

Commit

Permalink
Add support for reduction in wg to sg transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel committed Dec 24, 2024
1 parent 5b2a98e commit 6f250a4
Show file tree
Hide file tree
Showing 11 changed files with 515 additions and 42 deletions.
238 changes: 207 additions & 31 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class WGToSGInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
// row = i / cols
// col = i % cols
auto sgIdY =
rewriter.create<mlir::index::FloorDivSOp>(loc, sgID, sgLayoutDimYConst);
rewriter.create<mlir::index::DivUOp>(loc, sgID, sgLayoutDimYConst);
auto sgIdX =
rewriter.create<mlir::index::RemUOp>(loc, sgID, sgLayoutDimYConst);

Expand Down Expand Up @@ -496,7 +496,7 @@ class WGToSGArithConstantOpPattern
auto valueType = mlir::dyn_cast<mlir::VectorType>(value.getType());
auto wgTileShape = valueType.getShape();

if (!value || value.getType().getRank() != 2)
if (!value)
return mlir::failure();

auto mapAttr =
Expand All @@ -507,8 +507,21 @@ class WGToSGArithConstantOpPattern

auto sgData = mapAttr.getSgData();
auto sgLayout = mapAttr.getSgLayout();
mlir::SmallVector<int64_t> outputShape;
// If WG tile rank is 1, set the output shape as the
// non-unit dim of sgData
if(wgTileShape.size() == 1) {
if(sgData[0] == 1)
outputShape.push_back(sgData[1]);
else
outputShape.push_back(sgData[0]);
} else {
outputShape.push_back(sgData[0]);
outputShape.push_back(sgData[1]);
}

auto newTy =
mlir::VectorType::get({sgData[0], sgData[1]}, value.getElementType());
mlir::VectorType::get(outputShape, value.getElementType());

llvm::SmallVector<mlir::Attribute> elems(
value.value_begin<mlir::Attribute>(),
Expand All @@ -522,12 +535,20 @@ class WGToSGArithConstantOpPattern
auto attr = mlir::DenseElementsAttr::get(newTy, newValues);

size_t numOps;
if (sgLayout[0] * sgData[0] == wgTileShape[0] &&
sgLayout[1] * sgData[1] == wgTileShape[1])
numOps = 1; // 1:1 mapping
else
numOps = (wgTileShape[0] / (sgLayout[0] * sgData[0])) +
(wgTileShape[1] / (sgLayout[1] * sgData[1]));
// If WG tile is 1D vector just support 1:1 mapping.
// TODO: Support round robin for 1D
if(wgTileShape.size() == 1) {
if (sgLayout[0] * sgData[0] == wgTileShape[0] ||
sgLayout[1] * sgData[1] == wgTileShape[0])
numOps = 1;
else
return mlir::failure();
} else if(sgLayout[0] * sgData[0] == wgTileShape[0] &&
sgLayout[1] * sgData[1] == wgTileShape[1]) {
numOps = 1;
} else
numOps = (wgTileShape[0] / (sgLayout[0] * sgData[0])) +
(wgTileShape[1] / (sgLayout[1] * sgData[1]));

llvm::SmallVector<::mlir::Value> newOps;
llvm::SmallVector<mlir::Type> newResultTypes;
Expand Down Expand Up @@ -706,9 +727,16 @@ class WGToSGXeTileConvertLayout

rewriter.setInsertionPoint(op);
// Allocate SLM
// TODO: Allocate slm as 1D array of i8, and then create the expected view on it.
auto slmTy = mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3);
auto bitWidth = elemTy.getIntOrFloatBitWidth();
auto flattenFactor = bitWidth / 8;
auto slmShape = resShape[0] * resShape[1] * flattenFactor;
auto slmTy = mlir::MemRefType::get(slmShape, rewriter.getI8Type(), {}, 3);
auto slm = rewriter.create<mlir::memref::AllocOp>(loc, slmTy);
ValueRange sizes;
auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto viewTy = mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3);
auto viewOp = rewriter.create<mlir::memref::ViewOp>(
op.getLoc(), viewTy, slm, zero, sizes);

// Get SG id
auto sgId = rewriter.create<mlir::gpu::SubgroupIdOp>(
Expand All @@ -724,7 +752,7 @@ class WGToSGXeTileConvertLayout
// x is row, y is col
// TODO: Floorsdiv and Remu are expensive. Find alterate.
auto storeSgIdX =
rewriter.create<mlir::index::FloorDivSOp>(loc, sgId, srcMapDimY);
rewriter.create<mlir::index::DivUOp>(loc, sgId, srcMapDimY);
auto storeSgIdY =
rewriter.create<mlir::index::RemUOp>(loc, sgId, srcMapDimY);

Expand All @@ -742,7 +770,7 @@ class WGToSGXeTileConvertLayout
auto storeOffsetY = rewriter.createOrFold<mlir::index::MulOp>(
loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1]));
auto storeInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, srcTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
loc, srcTileTy, viewOp, llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
//TODO: Set up cache attributes
rewriter.create<xetile::StoreTileOp>(loc, adaptor.getSource()[0],
storeInitTileOp, nullptr, nullptr, nullptr);
Expand All @@ -757,14 +785,14 @@ class WGToSGXeTileConvertLayout
mlir::VectorType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy);

auto dstMapDimY = createIndexConstant(indexType, dstSgLayout[1]);
auto loadSgIdX = rewriter.create<mlir::index::FloorDivSOp>(loc, sgId, dstMapDimY);
auto loadSgIdX = rewriter.create<mlir::index::DivUOp>(loc, sgId, dstMapDimY);
auto loadSgIdY = rewriter.create<mlir::index::RemUOp>(loc, sgId, dstMapDimY);
auto loadOffsetX = rewriter.createOrFold<mlir::index::MulOp>(
loc, loadSgIdX, createIndexConstant(indexType, dstMapSgData[0]));
auto loadOffsetY = rewriter.createOrFold<mlir::index::MulOp>(
loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1]));
auto loadInitTileOp = rewriter.create<xetile::InitTileOp>(
loc, dstTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
loc, dstTileTy, viewOp, llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
//TODO: Set up cache attributes
auto loadTile = rewriter.create<xetile::LoadTileOp>(
loc, newResTy, loadInitTileOp, mlir::Attribute(), nullptr, nullptr, nullptr);
Expand Down Expand Up @@ -834,6 +862,127 @@ class WGToSGPrefetchOpPattern : public XeOneToNConversion<xetile::PrefetchTileOp
}
};

class WGToSGVectorMultiDimReductionOp
: public XeOneToNConversion<mlir::vector::MultiDimReductionOp> {
using XeOneToNConversion<
mlir::vector::MultiDimReductionOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::vector::MultiDimReductionOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
auto resRank = resType.getShape().size();

auto mapAttr =
llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("map"));

if (!mapAttr) {
return mlir::failure();
}

auto sgData = mapAttr.getSgData();

auto src = adaptor.getSource()[0];
auto srcType = mlir::dyn_cast<mlir::VectorType>(src.getType());

if (resRank == 2) {
bool newReduceDim = sgData[0] == 1 ? 0 : 1;
mlir::SmallVector<int64_t> redDims{newReduceDim};
auto outputShape =
newReduceDim == 0 ? srcType.getDimSize(1) : srcType.getDimSize(0);
auto newTy = mlir::VectorType::get(outputShape, srcType.getElementType());

// ShapeCast acc to match reduction op shape.
auto acc = rewriter.create<vector::ShapeCastOp>(op->getLoc(), newTy,
adaptor.getAcc()[0]);

auto newOp = rewriter.create<mlir::vector::MultiDimReductionOp>(
op.getLoc(), newTy, op.getKind(), src, acc, redDims);

// Shape Cast the output of reduction back to 2D
auto accumalator = adaptor.getAcc()[0];
auto accumalatorType =
mlir::dyn_cast<mlir::VectorType>(accumalator.getType());
auto outputVectorTy = mlir::VectorType::get(
accumalatorType.getShape(), accumalatorType.getElementType());
auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
op.getLoc(), outputVectorTy, newOp);
rewriter.replaceOp(op, shapeCastOp);
return mlir::success();
}
// Regular 2D vector.multi_reduction
else {
auto reductionDims = op.getReductionDims();
if (reductionDims.size() != 1)
return mlir::failure();

bool reduceDim = reductionDims[0];
auto outputShape =
reduceDim == 0 ? srcType.getDimSize(1) : srcType.getDimSize(0);

mlir::SmallVector<int64_t> redDims{reduceDim};
auto newTy = mlir::VectorType::get(outputShape, srcType.getElementType());
auto newOp = rewriter.create<mlir::vector::MultiDimReductionOp>(
op.getLoc(), newTy, op.getKind(), adaptor.getSource()[0],
adaptor.getAcc()[0], redDims);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
}
};

// Shape cast will support going from 1D to 2D since the vector.multi_reduction
// produces 1D

class WGToSGVectorShapeCast
: public XeOneToNConversion<mlir::vector::ShapeCastOp> {
using XeOneToNConversion<mlir::vector::ShapeCastOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::vector::ShapeCastOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
auto resShape = resType.getShape();

// Assumption is 3D shape cast is used for partial reduction.
// So just replace it with the transformed source of shape_cast
if (resShape.size() == 3) {
for (mlir::Operation *userOp : op.getResult().getUsers()) {
// Check if the user operation is not a vector.multi_reduction
if (!isa<mlir::vector::MultiDimReductionOp>(userOp)) {
return mlir::failure();
}
}
rewriter.replaceOp(op, adaptor.getSource()[0]);
return mlir::success();
}

// One of the dims have to be a unit dim
if (resShape[0] != 1 && resShape[1] != 1)
return mlir::failure();

auto mapAttr =
llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("map"));

if (!mapAttr) {
return mlir::failure();
}

auto sgData = mapAttr.getSgData();
auto newTy =
mlir::VectorType::get({sgData[0], sgData[1]}, resType.getElementType());

auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(
op.getLoc(), newTy, adaptor.getSource()[0]);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
};

// Helper function to analyze the def-use chain of initTileOps. Currently we
// pattern match the following def-use chain as a candidate for
// load + tranpose optimization.
Expand All @@ -852,10 +1001,11 @@ void analyzeInitTileOps(mlir::Operation *op) {
if (!initOp->hasOneUse())
return mlir::WalkResult::skip();
ops.push_back(initOp);
auto user = *initOp->user_begin();
auto initOpUser = *initOp->user_begin();
// InitTileOp must be consumed by a ForOp
mlir::Operation *loadUser = nullptr, *updateOffsetUser = nullptr;
if (auto scfFor = llvm::dyn_cast_if_present<mlir::scf::ForOp>(user)) {
mlir::Operation *loadUser = nullptr;
mlir::BlockArgument loopArg;
if (auto scfFor = llvm::dyn_cast_if_present<mlir::scf::ForOp>(initOpUser)) {
auto argument = imex::getArgForOperand(scfFor, initOp.getResult());
int userCount = 0;
for (auto user : argument.getUsers()) {
Expand All @@ -865,13 +1015,37 @@ void analyzeInitTileOps(mlir::Operation *op) {
ops.push_back(scfFor);
ops.push_back(user);
} else if (llvm::isa<imex::xetile::UpdateTileOffsetOp>(user)) {
updateOffsetUser = user;
ops.push_back(scfFor);
ops.push_back(user);
}
// Nested scf.for's
// init_tile -> scf.for -> update_tile_offset
// |
// scf.for -> load_tile -> vector.transpose -> (pre-op) ->
// tile_mma
else if (auto scfFor =
llvm::dyn_cast_if_present<mlir::scf::ForOp>(user)) {
for (auto iterOperand : llvm::enumerate(scfFor.getInitArgs())) {
if (iterOperand.value() == argument) {
loopArg = scfFor.getRegionIterArgs()[iterOperand.index()];
break;
}
}

for (auto scfForUser : loopArg.getUsers()) {
if (llvm::isa<imex::xetile::LoadTileOp>(scfForUser)) {
loadUser = scfForUser;
ops.push_back(scfFor);
ops.push_back(scfForUser);
} else if (llvm::isa<imex::xetile::UpdateTileOffsetOp>(
scfForUser)) {
ops.push_back(scfFor);
ops.push_back(scfForUser);
}
}
}
}
// ForOp argument should have only two users, a load and an update offset
if (userCount != 2 || !(loadUser && updateOffsetUser))
if (!loadUser)
return mlir::WalkResult::skip();
} else
return mlir::WalkResult::skip();
Expand All @@ -888,16 +1062,15 @@ void analyzeInitTileOps(mlir::Operation *op) {

// Check if vector.transpose is consumed by TileMMA directly or
// is consumed by some pre-op and then TileMMA.
if(!llvm::isa<imex::xetile::TileMMAOp>(consumerOp)){
if(!OpTrait::hasElementwiseMappableTraits(consumerOp) &&
!(llvm::isa<mlir::vector::BroadcastOp>(consumerOp))) {
if (!llvm::isa<imex::xetile::TileMMAOp>(consumerOp)) {
if (!OpTrait::hasElementwiseMappableTraits(consumerOp) &&
!(llvm::isa<mlir::vector::BroadcastOp>(consumerOp))) {
return mlir::WalkResult::skip();
}
else {
} else {
if (!(consumerOp->hasOneUse() &&
llvm::isa<imex::xetile::TileMMAOp>(*consumerOp->user_begin())))
return mlir::WalkResult::skip();
}
return mlir::WalkResult::skip();
}
}

// At this point, we have a candidate def-use chain for optimization.
Expand All @@ -917,8 +1090,10 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast,
WGToSGXeTileConvertLayout, WGToSGPrefetchOpPattern, WGToSGArithExtFOpPattern,
WGToSGArithTruncFOpPattern>(patterns.getContext(), converter);
WGToSGArithTruncFOpPattern, WGToSGVectorShapeCast, WGToSGVectorMultiDimReductionOp
>(patterns.getContext(), converter);
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
WGToSGElementWiseOpPattern<mlir::math::SqrtOp, 1>,
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
WGToSGArithConstantOpPattern>(patterns.getContext(), converter);
}
Expand Down Expand Up @@ -1016,9 +1191,10 @@ class XeTileWgToSgPass
});

target.addDynamicallyLegalOp<mlir::arith::ConstantOp, mlir::arith::AddFOp,
mlir::math::ExpOp, mlir::arith::ExtFOp,
mlir::math::ExpOp, mlir::math::SqrtOp, mlir::arith::ExtFOp,
mlir::arith::TruncFOp, mlir::vector::TransposeOp,
mlir::vector::BroadcastOp>(
mlir::vector::BroadcastOp, mlir::vector::MultiDimReductionOp,
mlir::vector::ShapeCastOp>(
[&](mlir::Operation *op) -> bool {
auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
op->getAttr("map"));
Expand Down
2 changes: 1 addition & 1 deletion test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ gpu.module @test_gemm_btranspose{
%10 = arith.addi %8, %9 : index
%11 = xetile.init_tile %arg0[%10, %c0] : memref<16384x12288xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, memory_space = 0 : i32>>

//CHECK: %[[R7:.*]] = index.floordivs %[[R6]], %[[c8]]
//CHECK: %[[R7:.*]] = index.divu %[[R6]], %[[c8]]
//CHECK: %[[R8:.*]] = index.remu %[[R6]], %[[c8]]
//CHECK: %[[R9:.*]] = index.add %[[R8]], %[[c0]]
//CHECK: %[[R10:.*]] = index.remu %[[R9]], %[[c4]]
Expand Down
12 changes: 7 additions & 5 deletions test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,28 @@ gpu.module @test_convert_layout{
gpu.func @test_kernel() {
//CHECK: %[[c0:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
//CHECK: %[[c0_0:.*]] = arith.constant dense<0.000000e+00> : vector<8x256xf32>
//CHECK: %[[SLM:.*]] = memref.alloc() : memref<256x256xf32, 3>
//CHECK: %[[SLMALLOC:.*]] = memref.alloc() : memref<262144xi8, 3>
//CHECK: %[[cst_0:.*]] = arith.constant 0 : index
//CHECK: %[[SLMVIEW:.*]] = memref.view %[[SLMALLOC]][%[[cst_0]]][] : memref<262144xi8, 3> to memref<256x256xf32, 3>
//CHECK: %[[R0:.*]] = gpu.subgroup_id : index
//CHECK: %[[c4:.*]] = arith.constant 4 : index
//CHECK: %[[R1:.*]] = index.floordivs %[[R0]], %[[c4]]
//CHECK: %[[R1:.*]] = index.divu %[[R0]], %[[c4]]
//CHECK: %[[R2:.*]] = index.remu %[[R0]], %[[c4]]
//CHECK: %[[c32:.*]] = arith.constant 32 : index
//CHECK: %[[R3:.*]] = index.mul %[[R1]], %[[c32]]
//CHECK: %[[c64:.*]] = arith.constant 64 : index
//CHECK: %[[R4:.*]] = index.mul %[[R2]], %[[c64]]
//CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLM]][%[[R3]], %[[R4]]] : memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLMVIEW]][%[[R3]], %[[R4]]] : memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: xetile.store_tile %[[c0]], %[[INITTILESRCMAP]] : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: gpu.barrier
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[R5:.*]] = index.floordivs %[[R0]], %[[c1]]
//CHECK: %[[R5:.*]] = index.divu %[[R0]], %[[c1]]
//CHECK: %[[R6:.*]] = index.remu %[[R0]], %[[c1]]
//CHECK: %[[c8:.*]] = arith.constant 8 : index
//CHECK: %[[R7:.*]] = index.mul %[[R5]], %[[c8]]
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[R8:.*]] = index.mul %[[R6]], %[[c256]]
//CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLM]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLMVIEW]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILEDSTMAP]] : !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>> -> vector<8x256xf32>

%cst = arith.constant {map = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} dense<0.000000e+00> : vector<256x256xf32>
Expand Down
Loading

0 comments on commit 6f250a4

Please sign in to comment.