Skip to content

Commit 9dbce79

Browse files
KritooooKritoooo
authored andcommitted
[CIR][ThroughMLIR] fix BinOp, CmpOp Lowering to MLIR and lowering cir.vec.cmp to MLIR (#694)
This PR does Three things: 1. Fixes the BinOp lowering to MLIR issue where signed numbers were not handled correctly, and adds support for vector types. The corresponding test files have been modified. 2. Fixes the CmpOp lowering to MLIR issue where signed numbers were not handled correctly And modified test files. 3. Adds cir.vec.cmp lowering to MLIR along with the corresponding test files. I originally planned to complete the remaining cir.vec.* lowerings in this PR, but it seems there's quite a lot to do, so I'll split it into multiple PRs. --------- Co-authored-by: Kritoooo <[email protected]>
1 parent 9ea95d4 commit 9dbce79

File tree

11 files changed

+518
-392
lines changed

11 files changed

+518
-392
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 81 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -628,52 +628,60 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
628628
"inconsistent operands' types not supported yet");
629629
mlir::Type mlirType = getTypeConverter()->convertType(op.getType());
630630
assert((mlirType.isa<mlir::IntegerType>() ||
631-
mlirType.isa<mlir::FloatType>()) &&
631+
mlirType.isa<mlir::FloatType>() ||
632+
mlirType.isa<mlir::VectorType>()) &&
632633
"operand type not supported yet");
633634

635+
auto type = op.getLhs().getType();
636+
if (auto VecType = type.dyn_cast<mlir::cir::VectorType>()) {
637+
type = VecType.getEltType();
638+
}
639+
634640
switch (op.getKind()) {
635641
case mlir::cir::BinOpKind::Add:
636-
if (mlirType.isa<mlir::IntegerType>())
642+
if (type.isa<mlir::cir::IntType>())
637643
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(
638644
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
639645
else
640646
rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(
641647
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
642648
break;
643649
case mlir::cir::BinOpKind::Sub:
644-
if (mlirType.isa<mlir::IntegerType>())
650+
if (type.isa<mlir::cir::IntType>())
645651
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(
646652
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
647653
else
648654
rewriter.replaceOpWithNewOp<mlir::arith::SubFOp>(
649655
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
650656
break;
651657
case mlir::cir::BinOpKind::Mul:
652-
if (mlirType.isa<mlir::IntegerType>())
658+
if (type.isa<mlir::cir::IntType>())
653659
rewriter.replaceOpWithNewOp<mlir::arith::MulIOp>(
654660
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
655661
else
656662
rewriter.replaceOpWithNewOp<mlir::arith::MulFOp>(
657663
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
658664
break;
659665
case mlir::cir::BinOpKind::Div:
660-
if (mlirType.isa<mlir::IntegerType>()) {
661-
if (mlirType.isSignlessInteger())
666+
if (auto ty = type.dyn_cast<mlir::cir::IntType>()) {
667+
if (ty.isUnsigned())
662668
rewriter.replaceOpWithNewOp<mlir::arith::DivUIOp>(
663669
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
664670
else
665-
llvm_unreachable("integer mlirType not supported in CIR yet");
671+
rewriter.replaceOpWithNewOp<mlir::arith::DivSIOp>(
672+
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
666673
} else
667674
rewriter.replaceOpWithNewOp<mlir::arith::DivFOp>(
668675
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
669676
break;
670677
case mlir::cir::BinOpKind::Rem:
671-
if (mlirType.isa<mlir::IntegerType>()) {
672-
if (mlirType.isSignlessInteger())
678+
if (auto ty = type.dyn_cast<mlir::cir::IntType>()) {
679+
if (ty.isUnsigned())
673680
rewriter.replaceOpWithNewOp<mlir::arith::RemUIOp>(
674681
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
675682
else
676-
llvm_unreachable("integer mlirType not supported in CIR yet");
683+
rewriter.replaceOpWithNewOp<mlir::arith::RemSIOp>(
684+
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
677685
} else
678686
rewriter.replaceOpWithNewOp<mlir::arith::RemFOp>(
679687
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
@@ -703,144 +711,22 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
703711
mlir::LogicalResult
704712
matchAndRewrite(mlir::cir::CmpOp op, OpAdaptor adaptor,
705713
mlir::ConversionPatternRewriter &rewriter) const override {
706-
auto type = adaptor.getLhs().getType();
707-
auto integerType =
708-
mlir::IntegerType::get(getContext(), 1, mlir::IntegerType::Signless);
714+
auto type = op.getLhs().getType();
709715

710716
mlir::Value mlirResult;
711-
switch (op.getKind()) {
712-
case mlir::cir::CmpOpKind::gt: {
713-
if (type.isa<mlir::IntegerType>()) {
714-
mlir::arith::CmpIPredicate cmpIType;
715-
if (!type.isSignlessInteger())
716-
llvm_unreachable("integer type not supported in CIR yet");
717-
cmpIType = mlir::arith::CmpIPredicate::ugt;
718-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
719-
op.getLoc(), integerType,
720-
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
721-
adaptor.getLhs(), adaptor.getRhs());
722-
} else if (type.isa<mlir::FloatType>()) {
723-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
724-
op.getLoc(), integerType,
725-
mlir::arith::CmpFPredicateAttr::get(
726-
getContext(), mlir::arith::CmpFPredicate::UGT),
727-
adaptor.getLhs(), adaptor.getRhs(),
728-
mlir::arith::FastMathFlagsAttr::get(
729-
getContext(), mlir::arith::FastMathFlags::none));
730-
} else {
731-
llvm_unreachable("Unknown Operand Type");
732-
}
733-
break;
734-
}
735-
case mlir::cir::CmpOpKind::ge: {
736-
if (type.isa<mlir::IntegerType>()) {
737-
mlir::arith::CmpIPredicate cmpIType;
738-
if (!type.isSignlessInteger())
739-
llvm_unreachable("integer type not supported in CIR yet");
740-
cmpIType = mlir::arith::CmpIPredicate::uge;
741-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
742-
op.getLoc(), integerType,
743-
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
744-
adaptor.getLhs(), adaptor.getRhs());
745-
} else if (type.isa<mlir::FloatType>()) {
746-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
747-
op.getLoc(), integerType,
748-
mlir::arith::CmpFPredicateAttr::get(
749-
getContext(), mlir::arith::CmpFPredicate::UGE),
750-
adaptor.getLhs(), adaptor.getRhs(),
751-
mlir::arith::FastMathFlagsAttr::get(
752-
getContext(), mlir::arith::FastMathFlags::none));
753-
} else {
754-
llvm_unreachable("Unknown Operand Type");
755-
}
756-
break;
757-
}
758-
case mlir::cir::CmpOpKind::lt: {
759-
if (type.isa<mlir::IntegerType>()) {
760-
mlir::arith::CmpIPredicate cmpIType;
761-
if (!type.isSignlessInteger())
762-
llvm_unreachable("integer type not supported in CIR yet");
763-
cmpIType = mlir::arith::CmpIPredicate::ult;
764-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
765-
op.getLoc(), integerType,
766-
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
767-
adaptor.getLhs(), adaptor.getRhs());
768-
} else if (type.isa<mlir::FloatType>()) {
769-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
770-
op.getLoc(), integerType,
771-
mlir::arith::CmpFPredicateAttr::get(
772-
getContext(), mlir::arith::CmpFPredicate::ULT),
773-
adaptor.getLhs(), adaptor.getRhs(),
774-
mlir::arith::FastMathFlagsAttr::get(
775-
getContext(), mlir::arith::FastMathFlags::none));
776-
} else {
777-
llvm_unreachable("Unknown Operand Type");
778-
}
779-
break;
780-
}
781-
case mlir::cir::CmpOpKind::le: {
782-
if (type.isa<mlir::IntegerType>()) {
783-
mlir::arith::CmpIPredicate cmpIType;
784-
if (!type.isSignlessInteger())
785-
llvm_unreachable("integer type not supported in CIR yet");
786-
cmpIType = mlir::arith::CmpIPredicate::ule;
787-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
788-
op.getLoc(), integerType,
789-
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
790-
adaptor.getLhs(), adaptor.getRhs());
791-
} else if (type.isa<mlir::FloatType>()) {
792-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
793-
op.getLoc(), integerType,
794-
mlir::arith::CmpFPredicateAttr::get(
795-
getContext(), mlir::arith::CmpFPredicate::ULE),
796-
adaptor.getLhs(), adaptor.getRhs(),
797-
mlir::arith::FastMathFlagsAttr::get(
798-
getContext(), mlir::arith::FastMathFlags::none));
799-
} else {
800-
llvm_unreachable("Unknown Operand Type");
801-
}
802-
break;
803-
}
804-
case mlir::cir::CmpOpKind::eq: {
805-
if (type.isa<mlir::IntegerType>()) {
806-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
807-
op.getLoc(), integerType,
808-
mlir::arith::CmpIPredicateAttr::get(getContext(),
809-
mlir::arith::CmpIPredicate::eq),
810-
adaptor.getLhs(), adaptor.getRhs());
811-
} else if (type.isa<mlir::FloatType>()) {
812-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
813-
op.getLoc(), integerType,
814-
mlir::arith::CmpFPredicateAttr::get(
815-
getContext(), mlir::arith::CmpFPredicate::UEQ),
816-
adaptor.getLhs(), adaptor.getRhs(),
817-
mlir::arith::FastMathFlagsAttr::get(
818-
getContext(), mlir::arith::FastMathFlags::none));
819-
} else {
820-
llvm_unreachable("Unknown Operand Type");
821-
}
822-
break;
823-
}
824-
case mlir::cir::CmpOpKind::ne: {
825-
if (type.isa<mlir::IntegerType>()) {
826-
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
827-
op.getLoc(), integerType,
828-
mlir::arith::CmpIPredicateAttr::get(getContext(),
829-
mlir::arith::CmpIPredicate::ne),
830-
adaptor.getLhs(), adaptor.getRhs());
831-
} else if (type.isa<mlir::FloatType>()) {
832-
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
833-
op.getLoc(), integerType,
834-
mlir::arith::CmpFPredicateAttr::get(
835-
getContext(), mlir::arith::CmpFPredicate::UNE),
836-
adaptor.getLhs(), adaptor.getRhs(),
837-
mlir::arith::FastMathFlagsAttr::get(
838-
getContext(), mlir::arith::FastMathFlags::none));
839-
} else {
840-
llvm_unreachable("Unknown Operand Type");
841-
}
842-
break;
843-
}
717+
718+
if (auto ty = type.dyn_cast<mlir::cir::IntType>()) {
719+
auto kind = convertCmpKindToCmpIPredicate(op.getKind(), ty.isSigned());
720+
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
721+
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
722+
} else if (auto ty = type.dyn_cast<mlir::cir::CIRFPTypeInterface>()) {
723+
auto kind = convertCmpKindToCmpFPredicate(op.getKind());
724+
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
725+
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
726+
} else if (auto ty = type.dyn_cast<mlir::cir::PointerType>()) {
727+
llvm_unreachable("pointer comparison not supported yet");
728+
} else {
729+
return op.emitError() << "unsupported type for CmpOp: " << type;
844730
}
845731

846732
// MLIR comparison ops return i1, but cir::CmpOp returns the same type as
@@ -1143,6 +1029,39 @@ class CIRVectorExtractLowering
11431029
}
11441030
};
11451031

1032+
class CIRVectorCmpOpLowering
1033+
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
1034+
public:
1035+
using OpConversionPattern<mlir::cir::VecCmpOp>::OpConversionPattern;
1036+
1037+
mlir::LogicalResult
1038+
matchAndRewrite(mlir::cir::VecCmpOp op, OpAdaptor adaptor,
1039+
mlir::ConversionPatternRewriter &rewriter) const override {
1040+
assert(op.getType().isa<mlir::cir::VectorType>() &&
1041+
op.getLhs().getType().isa<mlir::cir::VectorType>() &&
1042+
op.getRhs().getType().isa<mlir::cir::VectorType>() &&
1043+
"Vector compare with non-vector type");
1044+
auto elementType =
1045+
op.getLhs().getType().cast<mlir::cir::VectorType>().getEltType();
1046+
mlir::Value bitResult;
1047+
if (auto intType = elementType.dyn_cast<mlir::cir::IntType>()) {
1048+
bitResult = rewriter.create<mlir::arith::CmpIOp>(
1049+
op.getLoc(),
1050+
convertCmpKindToCmpIPredicate(op.getKind(), intType.isSigned()),
1051+
adaptor.getLhs(), adaptor.getRhs());
1052+
} else if (elementType.isa<mlir::cir::CIRFPTypeInterface>()) {
1053+
bitResult = rewriter.create<mlir::arith::CmpFOp>(
1054+
op.getLoc(), convertCmpKindToCmpFPredicate(op.getKind()),
1055+
adaptor.getLhs(), adaptor.getRhs());
1056+
} else {
1057+
return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
1058+
}
1059+
rewriter.replaceOpWithNewOp<mlir::arith::ExtSIOp>(
1060+
op, typeConverter->convertType(op.getType()), bitResult);
1061+
return mlir::success();
1062+
}
1063+
};
1064+
11461065
class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
11471066
public:
11481067
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
@@ -1345,22 +1264,22 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13451264
mlir::TypeConverter &converter) {
13461265
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
13471266

1348-
patterns
1349-
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1350-
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1351-
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1352-
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1353-
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1354-
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1355-
CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering,
1356-
CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering,
1357-
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1358-
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1359-
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1360-
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1361-
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1362-
CIRVectorInsertLowering, CIRVectorExtractLowering>(
1363-
converter, patterns.getContext());
1267+
patterns.add<
1268+
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1269+
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1270+
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1271+
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1272+
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1273+
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1274+
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1275+
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1276+
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
1277+
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
1278+
CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
1279+
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
1280+
CIRVectorCreateLowering, CIRVectorInsertLowering,
1281+
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1282+
patterns.getContext());
13641283
}
13651284

13661285
static mlir::TypeConverter prepareTypeConverter() {

clang/lib/CIR/Lowering/ThroughMLIR/LowerToMLIRHelpers.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/IR/BuiltinAttributes.h"
55
#include "mlir/IR/BuiltinTypes.h"
66
#include "mlir/Transforms/DialectConversion.h"
7+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
78

89
template <typename T>
910
mlir::Value getConst(mlir::ConversionPatternRewriter &rewriter,
@@ -37,4 +38,46 @@ mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
3738
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
3839
}
3940

41+
mlir::arith::CmpIPredicate
42+
convertCmpKindToCmpIPredicate(mlir::cir::CmpOpKind kind, bool isSigned) {
43+
using CIR = mlir::cir::CmpOpKind;
44+
using arithCmpI = mlir::arith::CmpIPredicate;
45+
switch (kind) {
46+
case CIR::eq:
47+
return arithCmpI::eq;
48+
case CIR::ne:
49+
return arithCmpI::ne;
50+
case CIR::lt:
51+
return (isSigned ? arithCmpI::slt : arithCmpI::ult);
52+
case CIR::le:
53+
return (isSigned ? arithCmpI::sle : arithCmpI::ule);
54+
case CIR::gt:
55+
return (isSigned ? arithCmpI::sgt : arithCmpI::ugt);
56+
case CIR::ge:
57+
return (isSigned ? arithCmpI::sge : arithCmpI::uge);
58+
}
59+
llvm_unreachable("Unknown CmpOpKind");
60+
}
61+
62+
mlir::arith::CmpFPredicate
63+
convertCmpKindToCmpFPredicate(mlir::cir::CmpOpKind kind) {
64+
using CIR = mlir::cir::CmpOpKind;
65+
using arithCmpF = mlir::arith::CmpFPredicate;
66+
switch (kind) {
67+
case CIR::eq:
68+
return arithCmpF::OEQ;
69+
case CIR::ne:
70+
return arithCmpF::UNE;
71+
case CIR::lt:
72+
return arithCmpF::OLT;
73+
case CIR::le:
74+
return arithCmpF::OLE;
75+
case CIR::gt:
76+
return arithCmpF::OGT;
77+
case CIR::ge:
78+
return arithCmpF::OGE;
79+
}
80+
llvm_unreachable("Unknown CmpOpKind");
81+
}
82+
4083
#endif

0 commit comments

Comments
 (0)