Skip to content

Commit

Permalink
[BACKEND] Fix invalid intermediate IR in GPU to LLVM (#1810)
Browse files Browse the repository at this point in the history
arith.trunc op is not allowed to use index type. This causes the IR to
fail the verifier.
This doesn't cause a compilation failure as index are lowered to i32 in
the same pass.
However this creates intermediate IR that fails verifier which can make
things harder to debug.
  • Loading branch information
ThomasRaoux committed Jun 21, 2023
1 parent 4c0e3d9 commit c3cba05
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
30 changes: 30 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,35 @@ struct AbsFOpConversion
}
};

/// The lowering of index_cast becomes an integer conversion since index
/// becomes an integer. If the bit width of the source and target integer
/// types is the same, just erase the cast. If the target type is wider,
/// sign-extend the value, otherwise truncate it.
struct IndexCastOpLowering
: public ElementwiseOpConversionBase<arith::IndexCastOp,
IndexCastOpLowering> {
using Base =
ElementwiseOpConversionBase<arith::IndexCastOp, IndexCastOpLowering>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

Value createDestOp(arith::IndexCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto inElemTy =
this->getTypeConverter()->convertType(getElementType(op.getIn()));
unsigned targetBits = elemTy.getIntOrFloatBitWidth();
unsigned sourceBits = inElemTy.getIntOrFloatBitWidth();

if (targetBits == sourceBits)
return operands[0];
if (targetBits < sourceBits)
return rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, elemTy,
operands[0]);
return rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, elemTy, operands[0]);
}
};

void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
Expand Down Expand Up @@ -1240,6 +1269,7 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<TruncFOpConversion>(typeConverter, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, benefit);

patterns.add<FpToFpOpConversion>(typeConverter, benefit);

Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ struct GetProgramIdOpConversion

Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
return success();
}

Expand All @@ -410,7 +410,7 @@ struct GetNumProgramsOpConversion

Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);

return success();
}
Expand Down
5 changes: 2 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,9 @@ class ConvertTritonGPUOpToLLVMPatternBase {
}

Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto tid = rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
loc, ::mlir::gpu::Dimension::x);
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
}

// -----------------------------------------------------------------------
Expand Down

0 comments on commit c3cba05

Please sign in to comment.