[fix rebase] adjust code on top of new tensor descriptor (#670)
Dewei-Wang-sh authored Nov 21, 2023
1 parent 7ac6b72 commit 605ac35
Showing 2 changed files with 17 additions and 67 deletions.
82 changes: 16 additions & 66 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
@@ -227,7 +227,7 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
 /// offsetX, offsetY, blockInfo] for 2D tensor desc
 /// -> [base pointer, unused] for 1D and scattered tensor desc
 /// only base pointer is i64, others are i32
-class CreateNdDescToVCPattern : public OpConversionPattern<CreateNdDescOp> {
+class CreateNdDescToSPIRV : public OpConversionPattern<CreateNdDescOp> {
 public:
   using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
   LogicalResult
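
A minimal plain-C++ sketch of the payload layout the doc comment above describes, for orientation only; this is not code from the commit. The base pointer spanning two i32 lanes follows from "only base pointer is i64", the offsetX/offsetY lanes follow from the idx5/idx6 constants used later in this diff, and the meaning of the remaining lanes (surface shape, blockInfo) is an assumption.

    #include <cstdint>
    #include <cstring>

    // Hypothetical mirror of the 2D tensor-descriptor payload (8 x i32):
    //   lanes 0-1: base pointer (a single i64)
    //   lanes 2-4: assumed surface width / height / pitch
    //   lanes 5-6: offsetX / offsetY (the idx5/idx6 extracts in this diff)
    //   lane    7: assumed blockInfo
    struct TensorDescPayload {
      uint32_t lanes[8];

      uint64_t base() const {
        uint64_t b;
        std::memcpy(&b, &lanes[0], sizeof(b)); // reinterpret lanes 0-1 as i64
        return b;
      }
      uint32_t offsetX() const { return lanes[5]; }
      uint32_t offsetY() const { return lanes[6]; }
    };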
@@ -313,7 +313,7 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern<UpdateNDOffsetOp> {
     auto i32Type = rewriter.getI32Type();
     auto offsets = adaptor.getOffsets();
     auto desc = adaptor.getTensorDesc();
-    for (auto i = 0; i < offsets.size(); i++) {
+    for (size_t i = 0; i < offsets.size(); i++) {
       auto offset = offsets[i];
       if (auto cst = dyn_cast<spirv::ConstantOp>(offset.getDefiningOp()))
         if (auto attr = dyn_cast<mlir::IntegerAttr>(cst.getValue());
@@ -1198,7 +1198,7 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
 
 void imex::populateXeGPUToVCIntrinsicsPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToVCPattern, CreateDescToVCPattern, DpasToVCPattern,
+  patterns.add<CreateNdDescToSPIRV, CreateDescToVCPattern, DpasToVCPattern,
                AllocNbarrierToVCPattern, CreateNbarrierToVCPattern,
                NbarrierArriveToVCPattern, NbarrierWaitToVCPattern,
                CompilerHintToVCPattern, MfenceToVCPattern, VectorShapeCast,
@@ -1245,56 +1245,6 @@ encodeGenISAVectorType(ConversionPatternRewriter &rewriter, VectorType type,
   return std::make_pair(str, newType);
 }
 
-class CreateNdDescToGenISA : public OpConversionPattern<CreateNdDescOp> {
-public:
-  using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(CreateNdDescOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto i32Type = rewriter.getI32Type();
-    auto i64Type = rewriter.getI64Type();
-    auto v4i32 = VectorType::get(4, i32Type);
-    auto v2i64 = VectorType::get(2, i64Type);
-    Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
-    auto createIntConstant = [&](Type type, unsigned value) {
-      auto attr = rewriter.getIntegerAttr(type, value);
-      return rewriter.create<spirv::ConstantOp>(loc, type, attr);
-    };
-    auto base = rewriter.create<spirv::ConvertPtrToUOp>(loc, i64Type,
-                                                        adaptor.getSource());
-    auto idx0 = createIntConstant(i32Type, 0);
-    payLoad =
-        rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
-    auto tileType = op.getTensorDesc().getType();
-    auto rank = tileType.getRank();
-    if (rank == 2) {
-      payLoad = rewriter.create<spirv::BitcastOp>(loc, v4i32, payLoad);
-      auto createOffset = [&](unsigned idx) -> Value {
-        Value val;
-        if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) {
-          val = op.getOffsets()[idx];
-          val = rewriter.create<arith::TruncIOp>(loc, i32Type, val);
-        } else {
-          val = createIntConstant(i32Type, op.getStaticOffsets()[idx]);
-        }
-        return val;
-      };
-      auto offsetX = createOffset(1);
-      auto offsetY = createOffset(0);
-      auto idx2 = createIntConstant(i32Type, 2);
-      auto idx3 = createIntConstant(i32Type, 3);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              offsetX, idx2);
-      payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
-                                                              offsetY, idx3);
-      payLoad = rewriter.create<spirv::BitcastOp>(loc, v2i64, payLoad);
-    }
-    rewriter.replaceOp(op, payLoad);
-    return success();
-  }
-};
-
 template <typename OpType>
 class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
 public:
@@ -1330,6 +1280,8 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
     auto i1Type = rewriter.getI1Type();
     auto i8Type = rewriter.getI8Type();
     auto i32Type = rewriter.getI32Type();
+    auto i64Type = rewriter.getI64Type();
+    auto v4i64 = VectorType::get(4, i64Type);
     auto vnni = false;
     auto transpose = false;
     if constexpr (isLoad) {
@@ -1347,8 +1299,8 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
     auto nBlks = createIntConstant(i8Type, 1);
     auto tensorDesc = adaptor.getTensorDesc();
     auto idx0 = createIntConstant(i32Type, 0);
-    auto base =
-        rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
+    auto cast = rewriter.create<spirv::BitcastOp>(loc, v4i64, tensorDesc);
+    auto base = rewriter.create<spirv::VectorExtractDynamicOp>(loc, cast, idx0);
     auto [typeStr, newType] = encodeGenISAVectorType(rewriter, vecType, false);
     SmallVector<Value> args;
     if (rank == 2) {
@@ -1360,7 +1312,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
       // static memref for now
       auto createDescOp =
           op.getTensorDesc().template getDefiningOp<CreateNdDescOp>();
-      auto memType = cast<MemRefType>(createDescOp.getSource().getType());
+      auto memType = llvm::cast<MemRefType>(createDescOp.getSource().getType());
       unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
       auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
       auto surfaceHeight = memType.getShape()[0] - 1;
@@ -1369,14 +1321,12 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
       auto surfaceW = createIntConstant(i32Type, surfaceWidth);
       auto surfaceH = createIntConstant(i32Type, surfaceHeight);
       auto surfaceP = createIntConstant(i32Type, surfacePitch);
-      auto v4i32 = VectorType::get(4, i32Type);
-      tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
-      auto idx2 = createIntConstant(i32Type, 2);
-      auto idx3 = createIntConstant(i32Type, 3);
+      auto idx5 = createIntConstant(i32Type, 5);
+      auto idx6 = createIntConstant(i32Type, 6);
       auto offsetX =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
+          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx5);
       auto offsetY =
-          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
+          rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx6);
       args.assign({base, surfaceW, surfaceH, surfaceP, offsetX, offsetY,
                    elemSize, blockW, blockH, nBlks, trans, transform});
       if constexpr (!isLoad && !isPrefetch) {
@@ -1391,7 +1341,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
       auto funcType =
           rewriter.getFunctionType(ValueRange(args).getTypes(), newType);
       Operation *opPtr = op;
-      lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
+      lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false);
       auto funcOp =
           rewriter.create<spirv::FunctionCallOp>(loc, newType, funcName, args);
       auto castTy = this->getTypeConverter()->convertType(op.getType());
@@ -1401,7 +1351,7 @@ class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
     } else {
       auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), {});
       Operation *opPtr = op;
-      lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
+      lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false);
       rewriter.create<spirv::FunctionCallOp>(loc, TypeRange(), funcName, args);
       rewriter.eraseOp(op);
     }
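
The bitcast-then-extract pair introduced in the hunks above is the 64-bit base-address read against the widened descriptor. A plain-C++ analogue of what it computes, assuming the 8 x i32 payload sketched earlier (a sketch, not repository code):

    #include <array>
    #include <cstdint>
    #include <cstring>

    // View the 8 x i32 payload as 4 x i64 (the spirv::BitcastOp) and read
    // lane 0 (the VectorExtractDynamicOp with idx0): lanes 0-1 hold the base.
    uint64_t extractBase(const std::array<uint32_t, 8> &payload) {
      std::array<uint64_t, 4> asI64{};
      static_assert(sizeof(asI64) == sizeof(payload), "bit widths must match");
      std::memcpy(asI64.data(), payload.data(), sizeof(asI64));
      return asI64[0];
    }

The new SSA value is named `cast`, which also explains the switch from the unqualified `cast<MemRefType>` to `llvm::cast<MemRefType>` seen above: inside that scope the local variable would otherwise shadow LLVM's casting function.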
@@ -1472,7 +1422,7 @@ class DpasToGenISA : public OpConversionPattern<DpasOp> {
     auto funcType =
         rewriter.getFunctionType(ValueRange(args).getTypes(), newType);
     Operation *opPtr = op;
-    lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
+    lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, false);
     auto funcOp =
         rewriter.create<spirv::FunctionCallOp>(loc, newType, funcName, args);
     rewriter.replaceOp(op, funcOp);
@@ -1482,7 +1432,7 @@ class DpasToGenISA : public OpConversionPattern<DpasOp> {
 
 void imex::populateXeGPUToGenISAPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToGenISA, DpasToGenISA,
+  patterns.add<CreateNdDescToSPIRV, DpasToGenISA,
               LoadStorePrefetchNdToGenISA<LoadNDOp>,
               LoadStorePrefetchNdToGenISA<StoreNDOp>,
               LoadStorePrefetchNdToGenISA<PrefetchNDOp>>(
2 changes: 1 addition & 1 deletion test/Conversion/XeGPUToSPIRV/gemm_basic.mlir
@@ -45,7 +45,7 @@ module @gemm attributes {gpu.container_module} {
     %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16>
     %2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32>
     %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32>
-    //call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
+    // call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
     return
   }
   func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
