From 1dcbd27118c753e9a19947fad7ff08430e26894a Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern@gmail.com>
Date: Fri, 16 Aug 2024 01:47:03 +0800
Subject: [PATCH] [CIR][CIRGen] Complex unary increment and decrement operator
 (#790)

This PR adds CIRGen and LLVMIR lowering for unary increment and
decrement expressions of complex types.

Currently blocked by #789 .
---
 clang/lib/CIR/CodeGen/CIRGenExpr.cpp          |   2 +-
 clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp   |  34 ++++-
 clang/lib/CIR/CodeGen/CIRGenFunction.h        |   5 +
 .../Dialect/Transforms/LoweringPrepare.cpp    |  10 +-
 clang/test/CIR/CodeGen/complex-arithmetic.c   | 140 ++++++++++++++++++
 5 files changed, 183 insertions(+), 8 deletions(-)

diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index eaed77193b4e..eab4ac3c9cab 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -1321,7 +1321,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
     LValue LV = buildLValue(E->getSubExpr());
 
     if (E->getType()->isAnyComplexType()) {
-      assert(0 && "not implemented");
+      buildComplexPrePostIncDec(E, LV, isInc, true /*isPre*/);
     } else {
       buildScalarPrePostIncDec(E, LV, isInc, isPre);
     }
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
index 33e3b67f8082..7472b039649b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
@@ -134,9 +134,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
 
   // Operators.
   mlir::Value VisitPrePostIncDec(const UnaryOperator *E, bool isInc,
-                                 bool isPre) {
-    llvm_unreachable("NYI");
-  }
+                                 bool isPre);
   mlir::Value VisitUnaryPostDec(const UnaryOperator *E) {
     return VisitPrePostIncDec(E, false, false);
   }
@@ -524,6 +522,12 @@ mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *E) {
   return CGF.buildCallExpr(E).getComplexVal();
 }
 
+mlir::Value ComplexExprEmitter::VisitPrePostIncDec(const UnaryOperator *E,
+                                                   bool isInc, bool isPre) {
+  LValue LV = CGF.buildLValue(E->getSubExpr());
+  return CGF.buildComplexPrePostIncDec(E, LV, isInc, isPre);
+}
+
 mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *E,
                                                QualType PromotionType) {
   QualType promotionTy = PromotionType.isNull()
@@ -956,3 +960,27 @@ LValue CIRGenFunction::buildComplexCompoundAssignmentLValue(
   RValue Val;
   return ComplexExprEmitter(*this).buildCompoundAssignLValue(E, Op, Val);
 }
+
+mlir::Value CIRGenFunction::buildComplexPrePostIncDec(const UnaryOperator *E,
+                                                      LValue LV, bool isInc,
+                                                      bool isPre) {
+  mlir::Value InVal = buildLoadOfComplex(LV, E->getExprLoc());
+
+  auto Loc = getLoc(E->getExprLoc());
+  auto OpKind =
+      isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec;
+  mlir::Value IncVal = builder.createUnaryOp(Loc, OpKind, InVal);
+
+  // Store the updated result through the lvalue.
+  buildStoreOfComplex(Loc, IncVal, LV, /*init*/ false);
+  if (getLangOpts().OpenMP)
+    llvm_unreachable("NYI");
+
+  // If this is a postinc, return the value read from memory, otherwise use the
+  // updated value.
+  return isPre ? IncVal : InVal;
+}
+
+mlir::Value CIRGenFunction::buildLoadOfComplex(LValue src, SourceLocation loc) {
+  return ComplexExprEmitter(*this).buildLoadOfLValue(src, loc);
+}
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 25a74293a19c..a3ca8857e8da 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -666,6 +666,8 @@ class CIRGenFunction : public CIRGenTypeCache {
 
   mlir::Value buildScalarPrePostIncDec(const UnaryOperator *E, LValue LV,
                                        bool isInc, bool isPre);
+  mlir::Value buildComplexPrePostIncDec(const UnaryOperator *E, LValue LV,
+                                        bool isInc, bool isPre);
 
   // Wrapper for function prototype sources. Wraps either a FunctionProtoType or
   // an ObjCMethodDecl.
@@ -799,6 +801,9 @@ class CIRGenFunction : public CIRGenTypeCache {
   mlir::Value buildLoadOfScalar(LValue lvalue, clang::SourceLocation Loc);
   mlir::Value buildLoadOfScalar(LValue lvalue, mlir::Location Loc);
 
+  /// Load a complex number from the specified l-value.
+  mlir::Value buildLoadOfComplex(LValue src, SourceLocation loc);
+
   Address buildLoadOfReference(LValue RefLVal, mlir::Location Loc,
                                LValueBaseInfo *PointeeBaseInfo = nullptr);
   LValue buildLoadOfReferenceLValue(LValue RefLVal, mlir::Location Loc);
diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index c4244b1b2e8f..ca6970f04830 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -356,10 +356,6 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
 
   auto loc = op.getLoc();
   auto opKind = op.getKind();
-  assert((opKind == mlir::cir::UnaryOpKind::Plus ||
-          opKind == mlir::cir::UnaryOpKind::Minus ||
-          opKind == mlir::cir::UnaryOpKind::Not) &&
-         "invalid unary op kind on complex numbers");
 
   CIRBaseBuilderTy builder(getContext());
   builder.setInsertionPointAfter(op);
@@ -372,6 +368,12 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
   mlir::Value resultReal;
   mlir::Value resultImag;
   switch (opKind) {
+  case mlir::cir::UnaryOpKind::Inc:
+  case mlir::cir::UnaryOpKind::Dec:
+    resultReal = builder.createUnaryOp(loc, opKind, operandReal);
+    resultImag = operandImag;
+    break;
+
   case mlir::cir::UnaryOpKind::Plus:
   case mlir::cir::UnaryOpKind::Minus:
     resultReal = builder.createUnaryOp(loc, opKind, operandReal);
diff --git a/clang/test/CIR/CodeGen/complex-arithmetic.c b/clang/test/CIR/CodeGen/complex-arithmetic.c
index f7b85000ce6b..c2e86ca43f74 100644
--- a/clang/test/CIR/CodeGen/complex-arithmetic.c
+++ b/clang/test/CIR/CodeGen/complex-arithmetic.c
@@ -776,3 +776,143 @@ void builtin_conj() {
 // LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1
 
 // CHECK: }
+
+void pre_increment() {
+  ++cd1;
+  ++ci1;
+}
+
+//   CLANG: @pre_increment
+// CPPLANG: @_Z13pre_incrementv
+
+// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
+// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>
+
+//      LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
+// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1
+
+//      LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
+// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1
+
+// CHECK: }
+
+void post_increment() {
+  cd1++;
+  ci1++;
+}
+
+//   CLANG: @post_increment
+// CPPLANG: @_Z14post_incrementv
+
+// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
+// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>
+
+//      LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
+// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1
+
+//      LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
+// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1
+
+// CHECK: }
+
+void pre_decrement() {
+  --cd1;
+  --ci1;
+}
+
+//   CLANG: @pre_decrement
+// CPPLANG: @_Z13pre_decrementv
+
+// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
+// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>
+
+//      LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
+// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1
+
+//      LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
+// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1
+
+// CHECK: }
+
+void post_decrement() {
+  cd1--;
+  ci1--;
+}
+
+//   CLANG: @post_decrement
+// CPPLANG: @_Z14post_decrementv
+
+// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
+// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
+// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>
+
+//      CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
+// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
+// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>
+
+//      LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
+// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1
+
+//      LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
+// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
+// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
+// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
+// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1
+
+// CHECK: }