@@ -628,52 +628,60 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
628
628
" inconsistent operands' types not supported yet" );
629
629
mlir::Type mlirType = getTypeConverter ()->convertType (op.getType ());
630
630
assert ((mlirType.isa <mlir::IntegerType>() ||
631
- mlirType.isa <mlir::FloatType>()) &&
631
+ mlirType.isa <mlir::FloatType>() ||
632
+ mlirType.isa <mlir::VectorType>()) &&
632
633
" operand type not supported yet" );
633
634
635
+ auto type = op.getLhs ().getType ();
636
+ if (auto VecType = type.dyn_cast <mlir::cir::VectorType>()) {
637
+ type = VecType.getEltType ();
638
+ }
639
+
634
640
switch (op.getKind ()) {
635
641
case mlir::cir::BinOpKind::Add:
636
- if (mlirType .isa <mlir::IntegerType >())
642
+ if (type .isa <mlir::cir::IntType >())
637
643
rewriter.replaceOpWithNewOp <mlir::arith::AddIOp>(
638
644
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
639
645
else
640
646
rewriter.replaceOpWithNewOp <mlir::arith::AddFOp>(
641
647
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
642
648
break ;
643
649
case mlir::cir::BinOpKind::Sub:
644
- if (mlirType .isa <mlir::IntegerType >())
650
+ if (type .isa <mlir::cir::IntType >())
645
651
rewriter.replaceOpWithNewOp <mlir::arith::SubIOp>(
646
652
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
647
653
else
648
654
rewriter.replaceOpWithNewOp <mlir::arith::SubFOp>(
649
655
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
650
656
break ;
651
657
case mlir::cir::BinOpKind::Mul:
652
- if (mlirType .isa <mlir::IntegerType >())
658
+ if (type .isa <mlir::cir::IntType >())
653
659
rewriter.replaceOpWithNewOp <mlir::arith::MulIOp>(
654
660
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
655
661
else
656
662
rewriter.replaceOpWithNewOp <mlir::arith::MulFOp>(
657
663
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
658
664
break ;
659
665
case mlir::cir::BinOpKind::Div:
660
- if (mlirType. isa <mlir::IntegerType >()) {
661
- if (mlirType. isSignlessInteger ())
666
+ if (auto ty = type. dyn_cast <mlir::cir::IntType >()) {
667
+ if (ty. isUnsigned ())
662
668
rewriter.replaceOpWithNewOp <mlir::arith::DivUIOp>(
663
669
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
664
670
else
665
- llvm_unreachable (" integer mlirType not supported in CIR yet" );
671
+ rewriter.replaceOpWithNewOp <mlir::arith::DivSIOp>(
672
+ op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
666
673
} else
667
674
rewriter.replaceOpWithNewOp <mlir::arith::DivFOp>(
668
675
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
669
676
break ;
670
677
case mlir::cir::BinOpKind::Rem:
671
- if (mlirType. isa <mlir::IntegerType >()) {
672
- if (mlirType. isSignlessInteger ())
678
+ if (auto ty = type. dyn_cast <mlir::cir::IntType >()) {
679
+ if (ty. isUnsigned ())
673
680
rewriter.replaceOpWithNewOp <mlir::arith::RemUIOp>(
674
681
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
675
682
else
676
- llvm_unreachable (" integer mlirType not supported in CIR yet" );
683
+ rewriter.replaceOpWithNewOp <mlir::arith::RemSIOp>(
684
+ op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
677
685
} else
678
686
rewriter.replaceOpWithNewOp <mlir::arith::RemFOp>(
679
687
op, mlirType, adaptor.getLhs (), adaptor.getRhs ());
@@ -703,144 +711,22 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
703
711
mlir::LogicalResult
704
712
matchAndRewrite (mlir::cir::CmpOp op, OpAdaptor adaptor,
705
713
mlir::ConversionPatternRewriter &rewriter) const override {
706
- auto type = adaptor.getLhs ().getType ();
707
- auto integerType =
708
- mlir::IntegerType::get (getContext (), 1 , mlir::IntegerType::Signless);
714
+ auto type = op.getLhs ().getType ();
709
715
710
716
mlir::Value mlirResult;
711
- switch (op.getKind ()) {
712
- case mlir::cir::CmpOpKind::gt: {
713
- if (type.isa <mlir::IntegerType>()) {
714
- mlir::arith::CmpIPredicate cmpIType;
715
- if (!type.isSignlessInteger ())
716
- llvm_unreachable (" integer type not supported in CIR yet" );
717
- cmpIType = mlir::arith::CmpIPredicate::ugt;
718
- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
719
- op.getLoc (), integerType,
720
- mlir::arith::CmpIPredicateAttr::get (getContext (), cmpIType),
721
- adaptor.getLhs (), adaptor.getRhs ());
722
- } else if (type.isa <mlir::FloatType>()) {
723
- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
724
- op.getLoc (), integerType,
725
- mlir::arith::CmpFPredicateAttr::get (
726
- getContext (), mlir::arith::CmpFPredicate::UGT),
727
- adaptor.getLhs (), adaptor.getRhs (),
728
- mlir::arith::FastMathFlagsAttr::get (
729
- getContext (), mlir::arith::FastMathFlags::none));
730
- } else {
731
- llvm_unreachable (" Unknown Operand Type" );
732
- }
733
- break ;
734
- }
735
- case mlir::cir::CmpOpKind::ge: {
736
- if (type.isa <mlir::IntegerType>()) {
737
- mlir::arith::CmpIPredicate cmpIType;
738
- if (!type.isSignlessInteger ())
739
- llvm_unreachable (" integer type not supported in CIR yet" );
740
- cmpIType = mlir::arith::CmpIPredicate::uge;
741
- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
742
- op.getLoc (), integerType,
743
- mlir::arith::CmpIPredicateAttr::get (getContext (), cmpIType),
744
- adaptor.getLhs (), adaptor.getRhs ());
745
- } else if (type.isa <mlir::FloatType>()) {
746
- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
747
- op.getLoc (), integerType,
748
- mlir::arith::CmpFPredicateAttr::get (
749
- getContext (), mlir::arith::CmpFPredicate::UGE),
750
- adaptor.getLhs (), adaptor.getRhs (),
751
- mlir::arith::FastMathFlagsAttr::get (
752
- getContext (), mlir::arith::FastMathFlags::none));
753
- } else {
754
- llvm_unreachable (" Unknown Operand Type" );
755
- }
756
- break ;
757
- }
758
- case mlir::cir::CmpOpKind::lt: {
759
- if (type.isa <mlir::IntegerType>()) {
760
- mlir::arith::CmpIPredicate cmpIType;
761
- if (!type.isSignlessInteger ())
762
- llvm_unreachable (" integer type not supported in CIR yet" );
763
- cmpIType = mlir::arith::CmpIPredicate::ult;
764
- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
765
- op.getLoc (), integerType,
766
- mlir::arith::CmpIPredicateAttr::get (getContext (), cmpIType),
767
- adaptor.getLhs (), adaptor.getRhs ());
768
- } else if (type.isa <mlir::FloatType>()) {
769
- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
770
- op.getLoc (), integerType,
771
- mlir::arith::CmpFPredicateAttr::get (
772
- getContext (), mlir::arith::CmpFPredicate::ULT),
773
- adaptor.getLhs (), adaptor.getRhs (),
774
- mlir::arith::FastMathFlagsAttr::get (
775
- getContext (), mlir::arith::FastMathFlags::none));
776
- } else {
777
- llvm_unreachable (" Unknown Operand Type" );
778
- }
779
- break ;
780
- }
781
- case mlir::cir::CmpOpKind::le: {
782
- if (type.isa <mlir::IntegerType>()) {
783
- mlir::arith::CmpIPredicate cmpIType;
784
- if (!type.isSignlessInteger ())
785
- llvm_unreachable (" integer type not supported in CIR yet" );
786
- cmpIType = mlir::arith::CmpIPredicate::ule;
787
- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
788
- op.getLoc (), integerType,
789
- mlir::arith::CmpIPredicateAttr::get (getContext (), cmpIType),
790
- adaptor.getLhs (), adaptor.getRhs ());
791
- } else if (type.isa <mlir::FloatType>()) {
792
- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
793
- op.getLoc (), integerType,
794
- mlir::arith::CmpFPredicateAttr::get (
795
- getContext (), mlir::arith::CmpFPredicate::ULE),
796
- adaptor.getLhs (), adaptor.getRhs (),
797
- mlir::arith::FastMathFlagsAttr::get (
798
- getContext (), mlir::arith::FastMathFlags::none));
799
- } else {
800
- llvm_unreachable (" Unknown Operand Type" );
801
- }
802
- break ;
803
- }
804
- case mlir::cir::CmpOpKind::eq: {
805
- if (type.isa <mlir::IntegerType>()) {
806
- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
807
- op.getLoc (), integerType,
808
- mlir::arith::CmpIPredicateAttr::get (getContext (),
809
- mlir::arith::CmpIPredicate::eq),
810
- adaptor.getLhs (), adaptor.getRhs ());
811
- } else if (type.isa <mlir::FloatType>()) {
812
- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
813
- op.getLoc (), integerType,
814
- mlir::arith::CmpFPredicateAttr::get (
815
- getContext (), mlir::arith::CmpFPredicate::UEQ),
816
- adaptor.getLhs (), adaptor.getRhs (),
817
- mlir::arith::FastMathFlagsAttr::get (
818
- getContext (), mlir::arith::FastMathFlags::none));
819
- } else {
820
- llvm_unreachable (" Unknown Operand Type" );
821
- }
822
- break ;
823
- }
824
- case mlir::cir::CmpOpKind::ne: {
825
- if (type.isa <mlir::IntegerType>()) {
826
- mlirResult = rewriter.create <mlir::arith::CmpIOp>(
827
- op.getLoc (), integerType,
828
- mlir::arith::CmpIPredicateAttr::get (getContext (),
829
- mlir::arith::CmpIPredicate::ne),
830
- adaptor.getLhs (), adaptor.getRhs ());
831
- } else if (type.isa <mlir::FloatType>()) {
832
- mlirResult = rewriter.create <mlir::arith::CmpFOp>(
833
- op.getLoc (), integerType,
834
- mlir::arith::CmpFPredicateAttr::get (
835
- getContext (), mlir::arith::CmpFPredicate::UNE),
836
- adaptor.getLhs (), adaptor.getRhs (),
837
- mlir::arith::FastMathFlagsAttr::get (
838
- getContext (), mlir::arith::FastMathFlags::none));
839
- } else {
840
- llvm_unreachable (" Unknown Operand Type" );
841
- }
842
- break ;
843
- }
717
+
718
+ if (auto ty = type.dyn_cast <mlir::cir::IntType>()) {
719
+ auto kind = convertCmpKindToCmpIPredicate (op.getKind (), ty.isSigned ());
720
+ mlirResult = rewriter.create <mlir::arith::CmpIOp>(
721
+ op.getLoc (), kind, adaptor.getLhs (), adaptor.getRhs ());
722
+ } else if (auto ty = type.dyn_cast <mlir::cir::CIRFPTypeInterface>()) {
723
+ auto kind = convertCmpKindToCmpFPredicate (op.getKind ());
724
+ mlirResult = rewriter.create <mlir::arith::CmpFOp>(
725
+ op.getLoc (), kind, adaptor.getLhs (), adaptor.getRhs ());
726
+ } else if (auto ty = type.dyn_cast <mlir::cir::PointerType>()) {
727
+ llvm_unreachable (" pointer comparison not supported yet" );
728
+ } else {
729
+ return op.emitError () << " unsupported type for CmpOp: " << type;
844
730
}
845
731
846
732
// MLIR comparison ops return i1, but cir::CmpOp returns the same type as
@@ -1143,6 +1029,39 @@ class CIRVectorExtractLowering
1143
1029
}
1144
1030
};
1145
1031
1032
+ class CIRVectorCmpOpLowering
1033
+ : public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
1034
+ public:
1035
+ using OpConversionPattern<mlir::cir::VecCmpOp>::OpConversionPattern;
1036
+
1037
+ mlir::LogicalResult
1038
+ matchAndRewrite (mlir::cir::VecCmpOp op, OpAdaptor adaptor,
1039
+ mlir::ConversionPatternRewriter &rewriter) const override {
1040
+ assert (op.getType ().isa <mlir::cir::VectorType>() &&
1041
+ op.getLhs ().getType ().isa <mlir::cir::VectorType>() &&
1042
+ op.getRhs ().getType ().isa <mlir::cir::VectorType>() &&
1043
+ " Vector compare with non-vector type" );
1044
+ auto elementType =
1045
+ op.getLhs ().getType ().cast <mlir::cir::VectorType>().getEltType ();
1046
+ mlir::Value bitResult;
1047
+ if (auto intType = elementType.dyn_cast <mlir::cir::IntType>()) {
1048
+ bitResult = rewriter.create <mlir::arith::CmpIOp>(
1049
+ op.getLoc (),
1050
+ convertCmpKindToCmpIPredicate (op.getKind (), intType.isSigned ()),
1051
+ adaptor.getLhs (), adaptor.getRhs ());
1052
+ } else if (elementType.isa <mlir::cir::CIRFPTypeInterface>()) {
1053
+ bitResult = rewriter.create <mlir::arith::CmpFOp>(
1054
+ op.getLoc (), convertCmpKindToCmpFPredicate (op.getKind ()),
1055
+ adaptor.getLhs (), adaptor.getRhs ());
1056
+ } else {
1057
+ return op.emitError () << " unsupported type for VecCmpOp: " << elementType;
1058
+ }
1059
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtSIOp>(
1060
+ op, typeConverter->convertType (op.getType ()), bitResult);
1061
+ return mlir::success ();
1062
+ }
1063
+ };
1064
+
1146
1065
class CIRCastOpLowering : public mlir ::OpConversionPattern<mlir::cir::CastOp> {
1147
1066
public:
1148
1067
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
@@ -1345,22 +1264,22 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1345
1264
mlir::TypeConverter &converter) {
1346
1265
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1347
1266
1348
- patterns
1349
- . add < CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1350
- CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1351
- CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1352
- CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1353
- CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1354
- CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1355
- CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering ,
1356
- CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering ,
1357
- CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering ,
1358
- CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering ,
1359
- CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering ,
1360
- CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering ,
1361
- CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering ,
1362
- CIRVectorInsertLowering, CIRVectorExtractLowering>(
1363
- converter, patterns.getContext ());
1267
+ patterns. add <
1268
+ CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering ,
1269
+ CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering ,
1270
+ CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering ,
1271
+ CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering ,
1272
+ CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering ,
1273
+ CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering ,
1274
+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering ,
1275
+ CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering ,
1276
+ CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering ,
1277
+ CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering ,
1278
+ CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering ,
1279
+ CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering ,
1280
+ CIRVectorCreateLowering, CIRVectorInsertLowering ,
1281
+ CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1282
+ patterns.getContext ());
1364
1283
}
1365
1284
1366
1285
static mlir::TypeConverter prepareTypeConverter () {
0 commit comments