Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Blocking] Rewrite blocking pass to generate small 2D Xetile Ops #978

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
return llvm::cast<TileType>(cloneWith(getShape(), elementType));
}

TileType clone(llvm::ArrayRef<int64_t> shape) {
return llvm::cast<TileType>(cloneWith(shape, getElementType()));
}

xetile::SubGroupMapAttr getSgMap() {
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
if (encoding)
Expand Down
6 changes: 1 addition & 5 deletions include/imex/Dialect/XeTile/Transforms/BlockingAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ class Block {

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Block blk);

// A pair of operator and operand index number representing
// the use point of a value.
typedef std::pair<mlir::Operation *, int64_t> UsePoint;

class BlockingAnalysis {
public:
explicit BlockingAnalysis(std::shared_ptr<XeuArchInterface> uArch) {
Expand All @@ -54,7 +50,7 @@ class BlockingAnalysis {

mlir::LogicalResult run(mlir::Operation *op);

Block getUseBlockSize(mlir::Value val, UsePoint point) const;
Block getUseBlockSize(mlir::Value val, mlir::OpOperand &point) const;
Block getDefBlockSize(mlir::Value val) const;
void printAnalysisResult();

Expand Down
5 changes: 4 additions & 1 deletion include/imex/Dialect/XeTile/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{
let options = [
Option<"device", "device", "std::string",
/*default=*/"\"pvc\"",
"gpu platform architecture where these ops are running">
"gpu platform architecture where these ops are running">,
Option<"EnableTransform", "enable-2d-transform", "bool",
/*default=*/"false",
"Using 2D transform or 4D Conversion.">
];
}

Expand Down
24 changes: 23 additions & 1 deletion include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,18 @@
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/OneToNTypeConversion.h>
using namespace mlir::xegpu;

namespace imex {

using PackFuncTy = std::function<mlir::TypedValue<mlir::VectorType>(
mlir::Value, mlir::Value, mlir::Location, mlir::OpBuilder &)>;

// A wrapper function to merge small vectors into a big one. It takes a
// range of mlir::Value objects with mlir::VectorType, and merge them
// into a big vector using the provided transformation function.
mlir::Value packVectorsWith(mlir::ValueRange ins, PackFuncTy op,
mlir::Location loc, mlir::OpBuilder &builder);

// Combine vectors vertically while keeping the logical data layout.
// As an example, given two vectors (2x4xf16) p and q, it will merge
// them in to a 4x4xf16 vector.
Expand All @@ -40,7 +50,19 @@ namespace imex {
// q5, q6, q7, q8
mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
mlir::Location loc,
mlir::PatternRewriter &rewriter);
mlir::OpBuilder &builder);

// merge vectors horizontally while keep the logical data layout.
// 1 2 3 4 + 10 11 12 = 1 2 3 4 10 11 12
// 5 6 7 8 13 14 15 5 6 7 8 13 14 15
// since there is no direct op in mlir exists, we will
// using ShapeCast and Shuffle to mimic it. It comes with
// cost of complex shuffle masks. the mask for the above one
// will be like this: 0 1 2 3 8 9 10
// 4 5 6 7 11 12 13
mlir::TypedValue<mlir::VectorType> concat(mlir::Value lhs, mlir::Value rhs,
mlir::Location loc,
mlir::OpBuilder &builder);

// It checks each GPUFuncOp in the module to see
// whether they have arguments and outputs with
Expand Down
20 changes: 2 additions & 18 deletions lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,6 @@

namespace imex {

using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;
using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
mlir::PatternRewriter &);

// see its description in XeTileOpConversion.cpp
extern VectorTypedValue concat(mlir::Value v1, mlir::Value v2,
mlir::Location loc,
mlir::PatternRewriter &rewriter);

// see its description in XeTileOpConversion.cpp
extern mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
std::function<funcTy> transFunc,
mlir::Location loc,
XeOneToNPatternRewriter &rewriter);

static mlir::Value createBinOp(mlir::vector::CombiningKind kind,
mlir::Value lhs, mlir::Value rhs,
mlir::Type elemTy, mlir::Location &loc,
Expand Down Expand Up @@ -318,8 +303,7 @@ class SgVectorMultiDimReductionOpPattern
// TODO: need a better way to represent the result (align with
// unpack/pack logic). currently we just shuffle them and cast it to the
// type/shape in xetile program.
auto reducedVal =
mergeVectorsWrapper(intermediates, concat, loc, rewriter);
auto reducedVal = packVectorsWith(intermediates, concat, loc, rewriter);
auto targetTy = mlir::VectorType::get({shape[1], shape[3]}, elemTy);
auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
reducedVal);
Expand All @@ -338,7 +322,7 @@ class SgVectorMultiDimReductionOpPattern
// currently we just shuffle them and cast it to the type/shape in
// xetile program.
auto reductionVal =
mergeVectorsWrapper(intermediates, concat, loc, rewriter);
packVectorsWith(intermediates, concat, loc, rewriter);
auto targetTy = mlir::VectorType::get({shape[0], shape[2]}, elemTy);
auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
reductionVal);
Expand Down
117 changes: 8 additions & 109 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,106 +38,6 @@ using mlir::vector::ShapeCastOp;
using mlir::vector::ShuffleOp;
using mlir::vector::SplatOp;

using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;
using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
mlir::PatternRewriter &);

// generate linearized shuffle mask for concat.
static llvm::SmallVector<int64_t>
getShuffleMask(llvm::ArrayRef<int64_t> shape1, llvm::ArrayRef<int64_t> shape2) {
assert(shape1.size() == shape2.size() && shape1.size() <= 2 &&
"only 1D/2D shape are supported.");
assert(shape1.drop_back() == shape2.drop_back() &&
"the row dim of the shapes should match.");
int64_t size1 = std::accumulate(shape1.begin(), shape1.end(), 1,
std::multiplies<int64_t>());
int64_t size2 = std::accumulate(shape2.begin(), shape2.end(), 1,
std::multiplies<int64_t>());
llvm::SmallVector<int64_t> mask(size1 + size2);
auto rows = shape1.size() == 1 ? 1 : shape1[0];
auto cols1 = shape1.size() == 1 ? shape1[0] : shape1[1];
auto cols2 = shape2.size() == 1 ? shape2[0] : shape2[1];
for (int64_t i = 0; i < rows; i++) {
int64_t s = i * (cols1 + cols2);
int64_t m = s + cols1;
int64_t e = m + cols2;
int64_t v1 = i * cols1;
int64_t v2 = size1 + i * cols2;
std::iota(mask.begin() + s, mask.begin() + m, v1);
std::iota(mask.begin() + m, mask.begin() + e, v2);
}
return mask;
}

// merge vectors horizontally while keep the logical data layout.
// 1 2 3 4 + 10 11 12 = 1 2 3 4 10 11 12
// 5 6 7 8 13 14 15 5 6 7 8 13 14 15
// since there is no direct op in mlir exists, we will
// using ShapeCast and Shuffle to mimic it. It comes with
// cost of complex shuffle masks. the mask for the above one
// will be like this: 0 1 2 3 8 9 10
// 4 5 6 7 11 12 13
VectorTypedValue concat(mlir::Value vecLeft, mlir::Value vecRight,
mlir::Location loc, mlir::PatternRewriter &rewriter) {
auto vecLeftTy = llvm::cast<mlir::VectorType>(vecLeft.getType());
auto vecRightTy = llvm::cast<mlir::VectorType>(vecRight.getType());

assert(vecLeftTy.getShape()[0] == vecLeftTy.getShape()[0] &&
"Operands of concat() do not have the same number of rows.");
assert(vecLeftTy.getRank() <= 2 &&
vecRightTy.getRank() == vecLeftTy.getRank() &&
"Currently concat only works on 1D/2D vector.");

auto elemTy = vecLeftTy.getElementType();
auto leftSize = vecLeftTy.getNumElements();
auto leftShape = vecLeftTy.getShape();
auto leftFlatTy = mlir::VectorType::get({vecLeftTy.getNumElements()}, elemTy);

auto rightSize = vecRightTy.getNumElements();
auto rightShape = vecRightTy.getShape();
auto rightFlatTy =
mlir::VectorType::get({vecRightTy.getNumElements()}, elemTy);

auto newShape = vecLeftTy.getRank() == 1
? llvm::SmallVector<int64_t>({leftSize + rightSize})
: llvm::SmallVector<int64_t>(
{leftShape[0], leftShape[1] + rightShape[1]});
auto castLeft = rewriter.create<ShapeCastOp>(loc, leftFlatTy, vecLeft);
auto castRight = rewriter.create<ShapeCastOp>(loc, rightFlatTy, vecRight);
auto mask = getShuffleMask(leftShape, rightShape);
auto shuffleOp = rewriter.create<ShuffleOp>(loc, castLeft, castRight, mask);
auto targetTy = mlir::VectorType::get(newShape, elemTy);
auto newOp = rewriter.create<ShapeCastOp>(loc, targetTy, shuffleOp);
return newOp;
}

// A wrapper function to merge small vectors into a big one. It takes a
// range of mlir::Value objects with mlir::VectorType, and merge them
// into a big vector using the provided transformation function.
mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
std::function<funcTy> transFunc,
mlir::Location loc,
XeOneToNPatternRewriter &rewriter) {
llvm::SmallVector<mlir::Value> shuffleOps(ins.begin(), ins.end());
while (shuffleOps.size() > 1) {
auto curr = shuffleOps;
shuffleOps.clear();
size_t currPairStartIdx{0};
while (currPairStartIdx < curr.size() - 1) {
size_t leftIdx{currPairStartIdx++};
size_t rightIdx{currPairStartIdx++};
auto newOp = transFunc(curr[leftIdx], curr[rightIdx], loc, rewriter);
shuffleOps.push_back(newOp);
}
if (currPairStartIdx < curr.size()) {
assert(currPairStartIdx == curr.size() - 1);
shuffleOps.push_back(curr[curr.size() - 1]);
}
}

return shuffleOps[0];
}

// Check that lowerUnpackOrPack will be able to evenly combine/split the input
// grid into the output grid.
static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp,
Expand All @@ -164,7 +64,7 @@ static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp,

// a unified function lowering Unpack and Pack ops.
static llvm::SmallVector<mlir::Value>
lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
lowerUnpackOrPack(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::ValueRange inputs, mlir::DenseI64ArrayAttr inBlkSizes,
mlir::DenseI64ArrayAttr outBlkSizes,
llvm::ArrayRef<int64_t> inGrids,
Expand All @@ -183,8 +83,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
auto idx = i * inGrids[1] + j;
valSet.push_back(inputs[idx]);
if (valSet.size() == static_cast<size_t>(nums)) {
auto newOp =
mergeVectorsWrapper(valSet, stack, op->getLoc(), rewriter);
auto newOp = packVectorsWith(valSet, stack, loc, rewriter);
intermediates[i / nums * inGrids[1] + j] = newOp;
valSet.clear();
}
Expand All @@ -205,7 +104,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
for (auto k = 0; k < nums; k++) {
llvm::SmallVector<int64_t> offsets({k * blkSizes[0], 0});
auto newOp = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), v, offsets, blkSizes, strides);
loc, v, offsets, blkSizes, strides);
auto idx = startPos + k * inGrids[1];
intermediates[idx] = newOp;
}
Expand All @@ -228,8 +127,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
for (auto j = 0; j < interGrids[1]; j++) {
valSet.push_back(intermediates[i * interGrids[1] + j]);
if (valSet.size() == nums) {
auto newOp =
mergeVectorsWrapper(valSet, concat, op->getLoc(), rewriter);
auto newOp = packVectorsWith(valSet, concat, loc, rewriter);
newOps.push_back(newOp);
valSet.clear();
}
Expand All @@ -245,7 +143,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
for (int64_t k = 0; k < nums; k++) {
llvm::SmallVector<int64_t> offsets({0, k * blkSizes[1]});
auto newOp = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), v, offsets, blkSizes, strides);
loc, v, offsets, blkSizes, strides);
newOps.push_back(newOp);
}
}
Expand Down Expand Up @@ -291,7 +189,7 @@ class SgTileUnpackOpPattern : public XeOneToNConversion<xetile::TileUnpackOp> {
}

rewriter.setInsertionPoint(op);
auto newOps = lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes,
auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), inputs, inBlkSizes,
outBlkSizes, inGrids, outGrids);

if (op->hasOneUse() && packOp && isUnpackPackCompatible(op, packOp)) {
Expand Down Expand Up @@ -327,7 +225,8 @@ class SgTilePackOpPattern : public XeOneToNConversion<xetile::TilePackOp> {
auto outGrids = outTy.getShape().take_front(2);
auto outBlkSizes = op.getInnerBlocksAttr();

auto newOps = lowerUnpackOrPack(rewriter, op, {input}, inBlkSizes,
rewriter.setInsertionPoint(op);
auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), {input}, inBlkSizes,
outBlkSizes, inGrids, outGrids);

// it is simple one-to-one mapping
Expand Down
Loading
Loading