Skip to content

Commit

Permalink
[bug fix] lower tensor descriptor as v8i32
Browse files Browse the repository at this point in the history
The backend needs the raw.send payload to be handled as a whole in order to work.
  • Loading branch information
Dewei-Wang-sh authored and silee2 committed Nov 20, 2023
1 parent 05e7fbc commit 3aa8ce3
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 102 deletions.
4 changes: 2 additions & 2 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ void GPUXToSPIRVPass::runOnOperation() {
});
typeConverter.addConversion(
[&](xegpu::TensorDescType type) -> ::mlir::Type {
auto i64Type = ::mlir::IntegerType::get(context, 64);
return ::mlir::VectorType::get(2, i64Type);
auto i32Type = ::mlir::IntegerType::get(context, 32);
return ::mlir::VectorType::get(8, i32Type);
});
typeConverter.addConversion([&](::mlir::VectorType type) -> ::mlir::Type {
unsigned rank = type.getRank();
Expand Down
168 changes: 68 additions & 100 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,11 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
}

/// @brief
/// convert the tensor descriptor to [2xi64] which is of the format
/// -> [base pointer: i64, offsetX: i32, offsetY: i32] for 2D tensor desc
/// -> [base pointer: i64, unused] for 1D and scattered tensor desc
/// Assemble the tensor descriptor payload as an 8xi32 vector with the layout
/// -> [base pointer, surface width, surface height, surface pitch,
/// offsetX, offsetY, blockInfo] for a 2D tensor desc
/// -> [base pointer, unused] for 1D and scattered tensor descs
/// Only the base pointer is i64 (it occupies two i32 lanes); the others are i32.
class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
public:
using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
Expand All @@ -235,9 +237,9 @@ class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
// payload
auto v4i32 = VectorType::get(4, i32Type);
auto v2i64 = VectorType::get(2, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
auto v8i32 = VectorType::get(8, i32Type);
auto v4i64 = VectorType::get(4, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v4i64);
auto createIntConstant = [&](Type type, unsigned value) {
auto attr = rewriter.getIntegerAttr(type, value);
return rewriter.create<spirv::ConstantOp>(loc, type, attr);
Expand All @@ -247,10 +249,34 @@ class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
auto idx0 = createIntConstant(i32Type, 0);
payLoad =
rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
payLoad = rewriter.create<spirv::BitcastOp>(loc, v8i32, payLoad);
auto tileType = op.getTensorDesc().getType();
auto rank = tileType.getRank();
if (rank == 2) {
payLoad = rewriter.create<spirv::BitcastOp>(loc, v4i32, payLoad);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
auto idx4 = createIntConstant(i32Type, 4);
auto idx5 = createIntConstant(i32Type, 5);
auto idx6 = createIntConstant(i32Type, 6);
auto idx7 = createIntConstant(i32Type, 7);
auto blockWidth = tileType.getShape()[1];
auto blockHeight = tileType.getShape()[0];
// fixme: support memref for now
auto memType = cast<MemRefType>(op.getSource().getType());
unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
auto surfaceHeight = memType.getShape()[0] - 1;
// fixme: pitch = width for now
auto surfacePitch = surfaceWidth;
auto surfaceW = createIntConstant(i32Type, surfaceWidth);
auto surfaceH = createIntConstant(i32Type, surfaceHeight);
auto surfaceP = createIntConstant(i32Type, surfacePitch);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
surfaceW, idx2);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
surfaceH, idx3);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
surfaceP, idx4);
auto createOffset = [&](unsigned idx) -> Value {
Value val;
if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) {
Expand All @@ -263,13 +289,14 @@ class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
};
auto offsetX = createOffset(1);
auto offsetY = createOffset(0);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetX, idx2);
offsetX, idx5);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetY, idx3);
payLoad = rewriter.create<spirv::BitcastOp>(loc, v2i64, payLoad);
offsetY, idx6);
unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1);
auto blockInfo = createIntConstant(i32Type, blockVal);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
blockInfo, idx7);
}
rewriter.replaceOp(op, payLoad);
return success();
Expand All @@ -283,32 +310,29 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern<UpdateNDOffsetOp> {
matchAndRewrite(UpdateNDOffsetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto desc = adaptor.getTensorDesc();
auto i32Type = rewriter.getI32Type();
auto v4i32 = VectorType::get(4, i32Type);
auto v2i64 = VectorType::get(2, rewriter.getI64Type());
Value cast = rewriter.create<spirv::BitcastOp>(loc, v4i32, desc);
auto offsets = adaptor.getOffsets();
auto desc = adaptor.getTensorDesc();
for (auto i = 0; i < offsets.size(); i++) {
auto offset = offsets[i];
if (auto cst = dyn_cast<spirv::ConstantOp>(offset.getDefiningOp()))
if (auto attr = dyn_cast<mlir::IntegerAttr>(cst.getValue());
attr && attr.getInt() == 0)
continue;
auto idx2 = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 2));
auto idx3 = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 3));
Value idx = i == 0 ? idx3 : idx2;
auto idx5 = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 5));
auto idx6 = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 6));
Value idx = i == 0 ? idx6 : idx5;
auto oldOffset =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, cast, idx);
rewriter.create<spirv::VectorExtractDynamicOp>(loc, desc, idx);
offset = rewriter.create<arith::TruncIOp>(loc, i32Type, offset);
auto newOffset =
rewriter.create<spirv::IAddOp>(loc, i32Type, oldOffset, offset);
cast = rewriter.create<spirv::VectorInsertDynamicOp>(loc, v4i32, cast,
newOffset, idx);
desc = rewriter.create<spirv::VectorInsertDynamicOp>(loc, desc, newOffset,
idx);
}
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(op, v2i64, cast);
rewriter.replaceOp(op, desc);
return success();
}
};
Expand All @@ -322,14 +346,16 @@ class CreateDescToVCPattern : public OpConversionPattern<CreateDescOp> {
auto loc = op.getLoc();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v2i64 = VectorType::get(2, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
auto v8i32 = VectorType::get(8, i32Type);
auto v4i64 = VectorType::get(4, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v4i64);
auto base = rewriter.create<spirv::ConvertPtrToUOp>(loc, i64Type,
adaptor.getSource());
auto idx0 = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0));
payLoad =
rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
payLoad = rewriter.create<spirv::BitcastOp>(loc, v8i32, payLoad);
rewriter.replaceOp(op, payLoad);
return success();
}
Expand Down Expand Up @@ -369,6 +395,8 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
auto i8Type = rewriter.getI8Type();
auto i16Type = rewriter.getI16Type();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v4i64 = VectorType::get(4, i64Type);
auto vnni = false;
auto transpose = false;
if constexpr (isLoad) {
Expand Down Expand Up @@ -396,11 +424,9 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
auto nBlks = createIntConstant(i8Type, 1);
auto tensorDesc = adaptor.getTensorDesc();
auto idx0 = createIntConstant(i32Type, 0);
auto base =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
std::string typeStr;
VectorType newType;
std::tie(typeStr, newType) = encodeVectorType(rewriter, vecType, rank == 1);
auto cast = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
auto base = rewriter.create<spirv::VectorExtractDynamicOp>(loc, cast, idx0);
auto [typeStr, newType] = encodeVectorType(rewriter, vecType, rank == 1);
SmallVector<Value> args;
if (rank == 2) {
auto blockWidth = tileType.getShape()[1];
Expand All @@ -411,7 +437,7 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
// static memref for now
auto createDescOp =
op.getTensorDesc().template getDefiningOp<CreateNdDescOp>();
auto memType = cast<MemRefType>(createDescOp.getSource().getType());
auto memType = llvm::cast<MemRefType>(createDescOp.getSource().getType());
unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
auto surfaceHeight = memType.getShape()[0] - 1;
Expand All @@ -420,14 +446,12 @@ class LoadStorePrefetchNdToLsc : public OpConversionPattern<OpType> {
auto surfaceW = createIntConstant(i32Type, surfaceWidth);
auto surfaceH = createIntConstant(i32Type, surfaceHeight);
auto surfaceP = createIntConstant(i32Type, surfacePitch);
auto v4i32 = VectorType::get(4, i32Type);
tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
auto idx5 = createIntConstant(i32Type, 5);
auto idx6 = createIntConstant(i32Type, 6);
auto offsetX =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx5);
auto offsetY =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx6);
args.assign({pred, l1CacheHint, l3CacheHint, dataum, trans, nBlks, blockW,
blockH, transform, base, surfaceW, surfaceH, surfaceP,
offsetX, offsetY});
Expand Down Expand Up @@ -539,7 +563,6 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern<OpType> {
auto i1Type = rewriter.getI1Type();
auto i8Type = rewriter.getI8Type();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto vnni = false;
auto transpose = false;
if constexpr (isLoad) {
Expand Down Expand Up @@ -607,66 +630,7 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern<OpType> {
rawSendMsg |= 1 << 25;
}
auto msg = createIntConstant(i32Type, rawSendMsg);
// payload
// payload is v8i32 = [base:i64, surfaceWidth:i32, surfaceHeight:i32,
// surfacePitch:i32, offsetX:i32, offsetY:i32, blockInfo:i32]
// the base/surfaceInfo/blockInfo are taken statically from the tensor desc
// while the offsetX/Y are dynamically updated
auto insertPoint = rewriter.saveInsertionPoint();
CreateNdDescOp createDescOp = *findDescOp(op.getTensorDesc());
rewriter.setInsertionPointAfter(createDescOp);
auto v8i32 = VectorType::get(8, i32Type);
auto v4i64 = VectorType::get(4, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v4i64);
auto idx0 = createIntConstant(i32Type, 0);
auto desc = rewriter.getRemappedValue(createDescOp);
auto base = rewriter.create<spirv::VectorExtractDynamicOp>(loc, desc, idx0);
payLoad =
rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
payLoad = rewriter.create<spirv::BitcastOp>(loc, v8i32, payLoad);
if (rank == 2) {
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
auto idx4 = createIntConstant(i32Type, 4);
auto idx5 = createIntConstant(i32Type, 5);
auto idx6 = createIntConstant(i32Type, 6);
auto idx7 = createIntConstant(i32Type, 7);
auto blockWidth = tileType.getShape()[1];
auto blockHeight = tileType.getShape()[0];
// fixme: support memref for now
auto memType = cast<MemRefType>(createDescOp.getSource().getType());
unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
auto surfaceHeight = memType.getShape()[0] - 1;
// fixme: pitch = width for now
auto surfacePitch = surfaceWidth;
auto surfaceW = createIntConstant(i32Type, surfaceWidth);
auto surfaceH = createIntConstant(i32Type, surfaceHeight);
auto surfaceP = createIntConstant(i32Type, surfacePitch);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
surfaceW, idx2);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
surfaceH, idx3);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
surfaceP, idx4);
unsigned blockVal = ((blockHeight - 1) << 8) | (blockWidth - 1);
auto blockInfo = createIntConstant(i32Type, blockVal);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
blockInfo, idx7);
rewriter.restoreInsertionPoint(insertPoint);
auto v4i32 = VectorType::get(4, i32Type);
auto tensorDesc = adaptor.getTensorDesc();
tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
auto offsetX =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
auto offsetY =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetX, idx5);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetY, idx6);
}
rewriter.restoreInsertionPoint(insertPoint);
auto payLoad = adaptor.getTensorDesc();
SmallVector<Value> args{modifier, execSize, pred, numSrc1, numDst,
sfid, extMsg, msg, payLoad};
if constexpr (isLoad) {
Expand Down Expand Up @@ -796,8 +760,10 @@ class GatherScatterToRawSend : public OpConversionPattern<OpType> {
auto i8Type = rewriter.getI8Type();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v4i64 = VectorType::get(4, i64Type);
auto tensorDesc = adaptor.getTensorDesc();
auto idx0 = createIntConstant(i32Type, 0);
tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
auto base =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
VectorType newType = VectorType::get(1, i32Type);
Expand Down Expand Up @@ -908,6 +874,7 @@ class AtomicToLsc : public OpConversionPattern<AtomicRMWOp> {
auto i16Type = rewriter.getI16Type();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v4i64 = VectorType::get(4, i64Type);
VectorType vecType = cast<VectorType>(op.getResult().getType());
std::string funcName = "llvm_genx_lsc_xatomic_stateless_";
auto [typeStr, newType] = encodeVectorType(rewriter, vecType, false, true);
Expand Down Expand Up @@ -936,6 +903,7 @@ class AtomicToLsc : public OpConversionPattern<AtomicRMWOp> {
auto mask = createIntConstant(i8Type, 0);

auto tensorDesc = adaptor.getTensorDesc();
tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
auto idx0 = createIntConstant(i32Type, 0);
auto base =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
Expand Down
Loading

0 comments on commit 3aa8ce3

Please sign in to comment.