Skip to content

Commit

Permalink
[CIR][CIRGen] Complex unary increment and decrement operator (#790)
Browse files Browse the repository at this point in the history
This PR adds CIRGen and LLVMIR lowering for unary increment and
decrement expressions of complex types.

Currently blocked by #789 .
  • Loading branch information
Lancern authored Aug 15, 2024
1 parent 5ebbd52 commit c1900b7
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 8 deletions.
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
LValue LV = buildLValue(E->getSubExpr());

if (E->getType()->isAnyComplexType()) {
assert(0 && "not implemented");
buildComplexPrePostIncDec(E, LV, isInc, true /*isPre*/);
} else {
buildScalarPrePostIncDec(E, LV, isInc, isPre);
}
Expand Down
34 changes: 31 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {

// Operators.
mlir::Value VisitPrePostIncDec(const UnaryOperator *E, bool isInc,
bool isPre) {
llvm_unreachable("NYI");
}
bool isPre);
mlir::Value VisitUnaryPostDec(const UnaryOperator *E) {
return VisitPrePostIncDec(E, false, false);
}
Expand Down Expand Up @@ -537,6 +535,12 @@ mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *E) {
return CGF.buildCallExpr(E).getComplexVal();
}

mlir::Value ComplexExprEmitter::VisitPrePostIncDec(const UnaryOperator *E,
bool isInc, bool isPre) {
LValue LV = CGF.buildLValue(E->getSubExpr());
return CGF.buildComplexPrePostIncDec(E, LV, isInc, isPre);
}

mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *E,
QualType PromotionType) {
QualType promotionTy = PromotionType.isNull()
Expand Down Expand Up @@ -969,3 +973,27 @@ LValue CIRGenFunction::buildComplexCompoundAssignmentLValue(
RValue Val;
return ComplexExprEmitter(*this).buildCompoundAssignLValue(E, Op, Val);
}

mlir::Value CIRGenFunction::buildComplexPrePostIncDec(const UnaryOperator *E,
LValue LV, bool isInc,
bool isPre) {
mlir::Value InVal = buildLoadOfComplex(LV, E->getExprLoc());

auto Loc = getLoc(E->getExprLoc());
auto OpKind =
isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec;
mlir::Value IncVal = builder.createUnaryOp(Loc, OpKind, InVal);

// Store the updated result through the lvalue.
buildStoreOfComplex(Loc, IncVal, LV, /*init*/ false);
if (getLangOpts().OpenMP)
llvm_unreachable("NYI");

// If this is a postinc, return the value read from memory, otherwise use the
// updated value.
return isPre ? IncVal : InVal;
}

mlir::Value CIRGenFunction::buildLoadOfComplex(LValue src, SourceLocation loc) {
return ComplexExprEmitter(*this).buildLoadOfLValue(src, loc);
}
5 changes: 5 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::Value buildScalarPrePostIncDec(const UnaryOperator *E, LValue LV,
bool isInc, bool isPre);
mlir::Value buildComplexPrePostIncDec(const UnaryOperator *E, LValue LV,
bool isInc, bool isPre);

// Wrapper for function prototype sources. Wraps either a FunctionProtoType or
// an ObjCMethodDecl.
Expand Down Expand Up @@ -799,6 +801,9 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Value buildLoadOfScalar(LValue lvalue, clang::SourceLocation Loc);
mlir::Value buildLoadOfScalar(LValue lvalue, mlir::Location Loc);

/// Load a complex number from the specified l-value.
mlir::Value buildLoadOfComplex(LValue src, SourceLocation loc);

Address buildLoadOfReference(LValue RefLVal, mlir::Location Loc,
LValueBaseInfo *PointeeBaseInfo = nullptr);
LValue buildLoadOfReferenceLValue(LValue RefLVal, mlir::Location Loc);
Expand Down
10 changes: 6 additions & 4 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,6 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {

auto loc = op.getLoc();
auto opKind = op.getKind();
assert((opKind == mlir::cir::UnaryOpKind::Plus ||
opKind == mlir::cir::UnaryOpKind::Minus ||
opKind == mlir::cir::UnaryOpKind::Not) &&
"invalid unary op kind on complex numbers");

CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
Expand All @@ -372,6 +368,12 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
mlir::Value resultReal;
mlir::Value resultImag;
switch (opKind) {
case mlir::cir::UnaryOpKind::Inc:
case mlir::cir::UnaryOpKind::Dec:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
resultImag = operandImag;
break;

case mlir::cir::UnaryOpKind::Plus:
case mlir::cir::UnaryOpKind::Minus:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
Expand Down
140 changes: 140 additions & 0 deletions clang/test/CIR/CodeGen/complex-arithmetic.c
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,143 @@ void builtin_conj() {
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1

// CHECK: }

void pre_increment() {
++cd1;
++ci1;
}

// CLANG: @pre_increment
// CPPLANG: @_Z13pre_incrementv

// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void post_increment() {
cd1++;
ci1++;
}

// CLANG: @post_increment
// CPPLANG: @_Z14post_incrementv

// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void pre_decrement() {
--cd1;
--ci1;
}

// CLANG: @pre_decrement
// CPPLANG: @_Z13pre_decrementv

// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void post_decrement() {
cd1--;
ci1--;
}

// CLANG: @post_decrement
// CPPLANG: @_Z14post_decrementv

// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

0 comments on commit c1900b7

Please sign in to comment.