Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen] Complex unary increment and decrement operator #790

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
LValue LV = buildLValue(E->getSubExpr());

if (E->getType()->isAnyComplexType()) {
assert(0 && "not implemented");
buildComplexPrePostIncDec(E, LV, isInc, true /*isPre*/);
} else {
buildScalarPrePostIncDec(E, LV, isInc, isPre);
}
Expand Down
34 changes: 31 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {

// Operators.
mlir::Value VisitPrePostIncDec(const UnaryOperator *E, bool isInc,
bool isPre) {
llvm_unreachable("NYI");
}
bool isPre);
mlir::Value VisitUnaryPostDec(const UnaryOperator *E) {
return VisitPrePostIncDec(E, false, false);
}
Expand Down Expand Up @@ -537,6 +535,12 @@ mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *E) {
return CGF.buildCallExpr(E).getComplexVal();
}

mlir::Value ComplexExprEmitter::VisitPrePostIncDec(const UnaryOperator *E,
bool isInc, bool isPre) {
LValue LV = CGF.buildLValue(E->getSubExpr());
return CGF.buildComplexPrePostIncDec(E, LV, isInc, isPre);
}

mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *E,
QualType PromotionType) {
QualType promotionTy = PromotionType.isNull()
Expand Down Expand Up @@ -969,3 +973,27 @@ LValue CIRGenFunction::buildComplexCompoundAssignmentLValue(
RValue Val;
return ComplexExprEmitter(*this).buildCompoundAssignLValue(E, Op, Val);
}

mlir::Value CIRGenFunction::buildComplexPrePostIncDec(const UnaryOperator *E,
LValue LV, bool isInc,
bool isPre) {
mlir::Value InVal = buildLoadOfComplex(LV, E->getExprLoc());

auto Loc = getLoc(E->getExprLoc());
auto OpKind =
isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec;
mlir::Value IncVal = builder.createUnaryOp(Loc, OpKind, InVal);

// Store the updated result through the lvalue.
buildStoreOfComplex(Loc, IncVal, LV, /*init*/ false);
if (getLangOpts().OpenMP)
llvm_unreachable("NYI");

// If this is a postinc, return the value read from memory, otherwise use the
// updated value.
return isPre ? IncVal : InVal;
}

mlir::Value CIRGenFunction::buildLoadOfComplex(LValue src, SourceLocation loc) {
return ComplexExprEmitter(*this).buildLoadOfLValue(src, loc);
}
5 changes: 5 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::Value buildScalarPrePostIncDec(const UnaryOperator *E, LValue LV,
bool isInc, bool isPre);
mlir::Value buildComplexPrePostIncDec(const UnaryOperator *E, LValue LV,
bool isInc, bool isPre);

// Wrapper for function prototype sources. Wraps either a FunctionProtoType or
// an ObjCMethodDecl.
Expand Down Expand Up @@ -799,6 +801,9 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Value buildLoadOfScalar(LValue lvalue, clang::SourceLocation Loc);
mlir::Value buildLoadOfScalar(LValue lvalue, mlir::Location Loc);

/// Load a complex number from the specified l-value.
mlir::Value buildLoadOfComplex(LValue src, SourceLocation loc);

Address buildLoadOfReference(LValue RefLVal, mlir::Location Loc,
LValueBaseInfo *PointeeBaseInfo = nullptr);
LValue buildLoadOfReferenceLValue(LValue RefLVal, mlir::Location Loc);
Expand Down
10 changes: 6 additions & 4 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,6 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {

auto loc = op.getLoc();
auto opKind = op.getKind();
assert((opKind == mlir::cir::UnaryOpKind::Plus ||
opKind == mlir::cir::UnaryOpKind::Minus ||
opKind == mlir::cir::UnaryOpKind::Not) &&
"invalid unary op kind on complex numbers");

CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
Expand All @@ -372,6 +368,12 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
mlir::Value resultReal;
mlir::Value resultImag;
switch (opKind) {
case mlir::cir::UnaryOpKind::Inc:
case mlir::cir::UnaryOpKind::Dec:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
resultImag = operandImag;
break;

case mlir::cir::UnaryOpKind::Plus:
case mlir::cir::UnaryOpKind::Minus:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
Expand Down
140 changes: 140 additions & 0 deletions clang/test/CIR/CodeGen/complex-arithmetic.c
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,143 @@ void builtin_conj() {
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1

// CHECK: }

void pre_increment() {
++cd1;
++ci1;
}

// CLANG: @pre_increment
// CPPLANG: @_Z13pre_incrementv

// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void post_increment() {
cd1++;
ci1++;
}

// CLANG: @post_increment
// CPPLANG: @_Z14post_incrementv

// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void pre_decrement() {
--cd1;
--ci1;
}

// CLANG: @pre_decrement
// CPPLANG: @_Z13pre_decrementv

// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void post_decrement() {
cd1--;
ci1--;
}

// CLANG: @post_decrement
// CPPLANG: @_Z14post_decrementv

// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }