Skip to content

Commit

Permalink
[Blocking] Rewrite blocking pass to generate small 2D Xetile Ops (#978)
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 authored Dec 9, 2024
1 parent 1c83c2f commit 227a0f7
Show file tree
Hide file tree
Showing 11 changed files with 3,640 additions and 253 deletions.
4 changes: 4 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
return llvm::cast<TileType>(cloneWith(getShape(), elementType));
}

TileType clone(llvm::ArrayRef<int64_t> shape) {
return llvm::cast<TileType>(cloneWith(shape, getElementType()));
}

xetile::SubGroupMapAttr getSgMap() {
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
if (encoding)
Expand Down
6 changes: 1 addition & 5 deletions include/imex/Dialect/XeTile/Transforms/BlockingAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ class Block {

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Block blk);

// A pair of operator and operand index number representing
// the use point of a value.
typedef std::pair<mlir::Operation *, int64_t> UsePoint;

class BlockingAnalysis {
public:
explicit BlockingAnalysis(std::shared_ptr<XeuArchInterface> uArch) {
Expand All @@ -54,7 +50,7 @@ class BlockingAnalysis {

mlir::LogicalResult run(mlir::Operation *op);

Block getUseBlockSize(mlir::Value val, UsePoint point) const;
Block getUseBlockSize(mlir::Value val, mlir::OpOperand &point) const;
Block getDefBlockSize(mlir::Value val) const;
void printAnalysisResult();

Expand Down
5 changes: 4 additions & 1 deletion include/imex/Dialect/XeTile/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{
let options = [
Option<"device", "device", "std::string",
/*default=*/"\"pvc\"",
"gpu platform architecture where these ops are running">
"gpu platform architecture where these ops are running">,
Option<"EnableTransform", "enable-2d-transform", "bool",
/*default=*/"false",
"Using 2D transform or 4D Conversion.">
];
}

Expand Down
24 changes: 23 additions & 1 deletion include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,18 @@
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/OneToNTypeConversion.h>
using namespace mlir::xegpu;

namespace imex {

using PackFuncTy = std::function<mlir::TypedValue<mlir::VectorType>(
mlir::Value, mlir::Value, mlir::Location, mlir::OpBuilder &)>;

// A wrapper function to merge small vectors into a big one. It takes a
// range of mlir::Value objects with mlir::VectorType, and merge them
// into a big vector using the provided transformation function.
mlir::Value packVectorsWith(mlir::ValueRange ins, PackFuncTy op,
mlir::Location loc, mlir::OpBuilder &builder);

// Combine vectors vertically while keeping the logical data layout.
// As an example, given two vectors (2x4xf16) p and q, it will merge
// them in to a 4x4xf16 vector.
Expand All @@ -40,7 +50,19 @@ namespace imex {
// q5, q6, q7, q8
mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
mlir::Location loc,
mlir::PatternRewriter &rewriter);
mlir::OpBuilder &builder);

// merge vectors horizontally while keep the logical data layout.
// 1 2 3 4 + 10 11 12 = 1 2 3 4 10 11 12
// 5 6 7 8 13 14 15 5 6 7 8 13 14 15
// since there is no direct op in mlir exists, we will
// using ShapeCast and Shuffle to mimic it. It comes with
// cost of complex shuffle masks. the mask for the above one
// will be like this: 0 1 2 3 8 9 10
// 4 5 6 7 11 12 13
mlir::TypedValue<mlir::VectorType> concat(mlir::Value lhs, mlir::Value rhs,
mlir::Location loc,
mlir::OpBuilder &builder);

// It checks each GPUFuncOp in the module to see
// whether they have arguments and outputs with
Expand Down
20 changes: 2 additions & 18 deletions lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,6 @@

namespace imex {

using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;
using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
mlir::PatternRewriter &);

// see its description in XeTileOpConversion.cpp
extern VectorTypedValue concat(mlir::Value v1, mlir::Value v2,
mlir::Location loc,
mlir::PatternRewriter &rewriter);

// see its description in XeTileOpConversion.cpp
extern mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
std::function<funcTy> transFunc,
mlir::Location loc,
XeOneToNPatternRewriter &rewriter);

static mlir::Value createBinOp(mlir::vector::CombiningKind kind,
mlir::Value lhs, mlir::Value rhs,
mlir::Type elemTy, mlir::Location &loc,
Expand Down Expand Up @@ -318,8 +303,7 @@ class SgVectorMultiDimReductionOpPattern
// TODO: need a better way to represent the result (align with
// unpack/pack logic). currently we just shuffle them and cast it to the
// type/shape in xetile program.
auto reducedVal =
mergeVectorsWrapper(intermediates, concat, loc, rewriter);
auto reducedVal = packVectorsWith(intermediates, concat, loc, rewriter);
auto targetTy = mlir::VectorType::get({shape[1], shape[3]}, elemTy);
auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
reducedVal);
Expand All @@ -338,7 +322,7 @@ class SgVectorMultiDimReductionOpPattern
// currently we just shuffle them and cast it to the type/shape in
// xetile program.
auto reductionVal =
mergeVectorsWrapper(intermediates, concat, loc, rewriter);
packVectorsWith(intermediates, concat, loc, rewriter);
auto targetTy = mlir::VectorType::get({shape[0], shape[2]}, elemTy);
auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(loc, targetTy,
reductionVal);
Expand Down
117 changes: 8 additions & 109 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,106 +38,6 @@ using mlir::vector::ShapeCastOp;
using mlir::vector::ShuffleOp;
using mlir::vector::SplatOp;

using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;
using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
mlir::PatternRewriter &);

// generate linearized shuffle mask for concat.
static llvm::SmallVector<int64_t>
getShuffleMask(llvm::ArrayRef<int64_t> shape1, llvm::ArrayRef<int64_t> shape2) {
assert(shape1.size() == shape2.size() && shape1.size() <= 2 &&
"only 1D/2D shape are supported.");
assert(shape1.drop_back() == shape2.drop_back() &&
"the row dim of the shapes should match.");
int64_t size1 = std::accumulate(shape1.begin(), shape1.end(), 1,
std::multiplies<int64_t>());
int64_t size2 = std::accumulate(shape2.begin(), shape2.end(), 1,
std::multiplies<int64_t>());
llvm::SmallVector<int64_t> mask(size1 + size2);
auto rows = shape1.size() == 1 ? 1 : shape1[0];
auto cols1 = shape1.size() == 1 ? shape1[0] : shape1[1];
auto cols2 = shape2.size() == 1 ? shape2[0] : shape2[1];
for (int64_t i = 0; i < rows; i++) {
int64_t s = i * (cols1 + cols2);
int64_t m = s + cols1;
int64_t e = m + cols2;
int64_t v1 = i * cols1;
int64_t v2 = size1 + i * cols2;
std::iota(mask.begin() + s, mask.begin() + m, v1);
std::iota(mask.begin() + m, mask.begin() + e, v2);
}
return mask;
}

// merge vectors horizontally while keep the logical data layout.
// 1 2 3 4 + 10 11 12 = 1 2 3 4 10 11 12
// 5 6 7 8 13 14 15 5 6 7 8 13 14 15
// since there is no direct op in mlir exists, we will
// using ShapeCast and Shuffle to mimic it. It comes with
// cost of complex shuffle masks. the mask for the above one
// will be like this: 0 1 2 3 8 9 10
// 4 5 6 7 11 12 13
VectorTypedValue concat(mlir::Value vecLeft, mlir::Value vecRight,
mlir::Location loc, mlir::PatternRewriter &rewriter) {
auto vecLeftTy = llvm::cast<mlir::VectorType>(vecLeft.getType());
auto vecRightTy = llvm::cast<mlir::VectorType>(vecRight.getType());

assert(vecLeftTy.getShape()[0] == vecLeftTy.getShape()[0] &&
"Operands of concat() do not have the same number of rows.");
assert(vecLeftTy.getRank() <= 2 &&
vecRightTy.getRank() == vecLeftTy.getRank() &&
"Currently concat only works on 1D/2D vector.");

auto elemTy = vecLeftTy.getElementType();
auto leftSize = vecLeftTy.getNumElements();
auto leftShape = vecLeftTy.getShape();
auto leftFlatTy = mlir::VectorType::get({vecLeftTy.getNumElements()}, elemTy);

auto rightSize = vecRightTy.getNumElements();
auto rightShape = vecRightTy.getShape();
auto rightFlatTy =
mlir::VectorType::get({vecRightTy.getNumElements()}, elemTy);

auto newShape = vecLeftTy.getRank() == 1
? llvm::SmallVector<int64_t>({leftSize + rightSize})
: llvm::SmallVector<int64_t>(
{leftShape[0], leftShape[1] + rightShape[1]});
auto castLeft = rewriter.create<ShapeCastOp>(loc, leftFlatTy, vecLeft);
auto castRight = rewriter.create<ShapeCastOp>(loc, rightFlatTy, vecRight);
auto mask = getShuffleMask(leftShape, rightShape);
auto shuffleOp = rewriter.create<ShuffleOp>(loc, castLeft, castRight, mask);
auto targetTy = mlir::VectorType::get(newShape, elemTy);
auto newOp = rewriter.create<ShapeCastOp>(loc, targetTy, shuffleOp);
return newOp;
}

// A wrapper function to merge small vectors into a big one. It takes a
// range of mlir::Value objects with mlir::VectorType, and merge them
// into a big vector using the provided transformation function.
mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
std::function<funcTy> transFunc,
mlir::Location loc,
XeOneToNPatternRewriter &rewriter) {
llvm::SmallVector<mlir::Value> shuffleOps(ins.begin(), ins.end());
while (shuffleOps.size() > 1) {
auto curr = shuffleOps;
shuffleOps.clear();
size_t currPairStartIdx{0};
while (currPairStartIdx < curr.size() - 1) {
size_t leftIdx{currPairStartIdx++};
size_t rightIdx{currPairStartIdx++};
auto newOp = transFunc(curr[leftIdx], curr[rightIdx], loc, rewriter);
shuffleOps.push_back(newOp);
}
if (currPairStartIdx < curr.size()) {
assert(currPairStartIdx == curr.size() - 1);
shuffleOps.push_back(curr[curr.size() - 1]);
}
}

return shuffleOps[0];
}

// Check that lowerUnpackOrPack will be able to evenly combine/split the input
// grid into the output grid.
static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp,
Expand All @@ -164,7 +64,7 @@ static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp,

// a unified function lowering Unpack and Pack ops.
static llvm::SmallVector<mlir::Value>
lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
lowerUnpackOrPack(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::ValueRange inputs, mlir::DenseI64ArrayAttr inBlkSizes,
mlir::DenseI64ArrayAttr outBlkSizes,
llvm::ArrayRef<int64_t> inGrids,
Expand All @@ -183,8 +83,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
auto idx = i * inGrids[1] + j;
valSet.push_back(inputs[idx]);
if (valSet.size() == static_cast<size_t>(nums)) {
auto newOp =
mergeVectorsWrapper(valSet, stack, op->getLoc(), rewriter);
auto newOp = packVectorsWith(valSet, stack, loc, rewriter);
intermediates[i / nums * inGrids[1] + j] = newOp;
valSet.clear();
}
Expand All @@ -205,7 +104,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
for (auto k = 0; k < nums; k++) {
llvm::SmallVector<int64_t> offsets({k * blkSizes[0], 0});
auto newOp = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), v, offsets, blkSizes, strides);
loc, v, offsets, blkSizes, strides);
auto idx = startPos + k * inGrids[1];
intermediates[idx] = newOp;
}
Expand All @@ -228,8 +127,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
for (auto j = 0; j < interGrids[1]; j++) {
valSet.push_back(intermediates[i * interGrids[1] + j]);
if (valSet.size() == nums) {
auto newOp =
mergeVectorsWrapper(valSet, concat, op->getLoc(), rewriter);
auto newOp = packVectorsWith(valSet, concat, loc, rewriter);
newOps.push_back(newOp);
valSet.clear();
}
Expand All @@ -245,7 +143,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
for (int64_t k = 0; k < nums; k++) {
llvm::SmallVector<int64_t> offsets({0, k * blkSizes[1]});
auto newOp = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), v, offsets, blkSizes, strides);
loc, v, offsets, blkSizes, strides);
newOps.push_back(newOp);
}
}
Expand Down Expand Up @@ -291,7 +189,7 @@ class SgTileUnpackOpPattern : public XeOneToNConversion<xetile::TileUnpackOp> {
}

rewriter.setInsertionPoint(op);
auto newOps = lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes,
auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), inputs, inBlkSizes,
outBlkSizes, inGrids, outGrids);

if (op->hasOneUse() && packOp && isUnpackPackCompatible(op, packOp)) {
Expand Down Expand Up @@ -327,7 +225,8 @@ class SgTilePackOpPattern : public XeOneToNConversion<xetile::TilePackOp> {
auto outGrids = outTy.getShape().take_front(2);
auto outBlkSizes = op.getInnerBlocksAttr();

auto newOps = lowerUnpackOrPack(rewriter, op, {input}, inBlkSizes,
rewriter.setInsertionPoint(op);
auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), {input}, inBlkSizes,
outBlkSizes, inGrids, outGrids);

// it is simple one-to-one mapping
Expand Down
Loading

0 comments on commit 227a0f7

Please sign in to comment.