Skip to content

Commit

Permalink
[CIR][CodeGen] initial support for dynamic_cast
Browse files Browse the repository at this point in the history
This patch introduces CIR CodeGen support for dynamic_cast. As an initial step,
this patch only adds support for cases where the destination type is not void*.
Support for dynamic_cast to void* will be added in future patches.
  • Loading branch information
Lancern committed Jan 30, 2024
1 parent 9bc8b47 commit ba4b5aa
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 4 deletions.
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
alloca->moveAfter(*std::prev(allocas.end()));
}
}

mlir::Value createPtrIsNull(mlir::Value ptr) {
return createNot(createPtrToBoolCast(ptr));
}
};

} // namespace cir
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,17 @@ class CIRGenCXXABI {

virtual void buildRethrow(CIRGenFunction &CGF, bool isNoReturn) = 0;
virtual void buildThrow(CIRGenFunction &CGF, const CXXThrowExpr *E) = 0;

virtual void buildBadCastCall(CIRGenFunction &CGF, mlir::Location loc) = 0;

virtual bool shouldDynamicCastCallBeNullChecked(bool SrcIsPtr,
QualType SrcRecordTy) = 0;

virtual mlir::Value buildDynamicCastCall(CIRGenFunction &CGF,
mlir::Location Loc, Address Value,
QualType SrcRecordTy,
QualType DestTy,
QualType DestRecordTy) = 0;
};

/// Creates and Itanium-family ABI
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,19 @@ RValue CIRGenFunction::GetUndefRValue(QualType Ty) {
return RValue::get(nullptr);
}

mlir::Value CIRGenFunction::buildRuntimeCall(mlir::Location loc,
mlir::cir::FuncOp callee,
ArrayRef<mlir::Value> args) {
auto call = builder.create<mlir::cir::CallOp>(loc, callee, args);
assert(call->getNumResults() <= 1 &&
"runtime functions have at most 1 result");

if (call->getNumResults() == 0)
return nullptr;

return call->getResult(0);
}

void CIRGenFunction::buildCallArg(CallArgList &args, const Expr *E,
QualType type) {
// TODO: Add the DisableDebugLocationUpdates helper
Expand Down
8 changes: 6 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "CIRGenValue.h"
#include "UnimplementedFeatureGuarding.h"

#include "clang/AST/ExprCXX.h"
#include "clang/AST/GlobalDecl.h"
#include "clang/Basic/Builtins.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
Expand Down Expand Up @@ -1565,7 +1566,10 @@ LValue CIRGenFunction::buildCastLValue(const CastExpr *E) {
assert(0 && "NYI");

case CK_Dynamic: {
assert(0 && "NYI");
LValue LV = buildLValue(E->getSubExpr());
Address V = LV.getAddress();
const auto *DCE = cast<CXXDynamicCastExpr>(E);
return MakeNaturalAlignAddrLValue(buildDynamicCast(V, DCE), E->getType());
}

case CK_ConstructorConversion:
Expand Down Expand Up @@ -2172,7 +2176,6 @@ LValue CIRGenFunction::buildLValue(const Expr *E) {
return buildPredefinedLValue(cast<PredefinedExpr>(E));
case Expr::CStyleCastExprClass:
case Expr::CXXFunctionalCastExprClass:
case Expr::CXXDynamicCastExprClass:
case Expr::CXXReinterpretCastExprClass:
case Expr::CXXConstCastExprClass:
case Expr::CXXAddrspaceCastExprClass:
Expand All @@ -2181,6 +2184,7 @@ LValue CIRGenFunction::buildLValue(const Expr *E) {
<< E->getStmtClassName() << "'";
assert(0 && "Use buildCastLValue below, remove me when adding testcase");
case Expr::CXXStaticCastExprClass:
case Expr::CXXDynamicCastExprClass:
case Expr::ImplicitCastExprClass:
return buildCastLValue(cast<CastExpr>(E));
case Expr::OpaqueValueExprClass:
Expand Down
85 changes: 85 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,88 @@ void CIRGenFunction::buildDeleteCall(const FunctionDecl *DeleteFD,
llvm_unreachable("NYI"); // DestroyingDeleteTag->eraseFromParent();
}
}

static mlir::Value buildDynamicCastToNull(CIRGenFunction &CGF,
mlir::Location Loc, QualType DestTy) {
mlir::Type DestCIRTy = CGF.ConvertType(DestTy);
assert(DestCIRTy.isa<mlir::cir::PointerType>() &&
"result of dynamic_cast should be a ptr");

mlir::Value NullPtrValue = CGF.getBuilder().getNullPtr(DestCIRTy, Loc);

if (!DestTy->isPointerType()) {
/// C++ [expr.dynamic.cast]p9:
/// A failed cast to reference type throws std::bad_cast
CGF.CGM.getCXXABI().buildBadCastCall(CGF, Loc);
}

return NullPtrValue;
}

mlir::Value CIRGenFunction::buildDynamicCast(Address ThisAddr,
const CXXDynamicCastExpr *DCE) {
auto loc = getLoc(DCE->getSourceRange());

CGM.buildExplicitCastExprType(DCE, this);
QualType destTy = DCE->getTypeAsWritten();
QualType srcTy = DCE->getSubExpr()->getType();

// C++ [expr.dynamic.cast]p7:
// If T is "pointer to cv void," then the result is a pointer to the most
// derived object pointed to by v.
bool isDynCastToVoid = destTy->isVoidPointerType();
QualType srcRecordTy;
QualType destRecordTy;
if (isDynCastToVoid) {
srcRecordTy = srcTy->getPointeeType();
// No DestRecordTy.
} else if (const PointerType *DestPTy = destTy->getAs<PointerType>()) {
srcRecordTy = srcTy->castAs<PointerType>()->getPointeeType();
destRecordTy = DestPTy->getPointeeType();
} else {
srcRecordTy = srcTy;
destRecordTy = destTy->castAs<ReferenceType>()->getPointeeType();
}

buildTypeCheck(TCK_DynamicOperation, DCE->getExprLoc(), ThisAddr.getPointer(),
srcRecordTy);

if (DCE->isAlwaysNull())
return buildDynamicCastToNull(*this, loc, destTy);

assert(srcRecordTy->isRecordType() && "source type must be a record type!");

// C++ [expr.dynamic.cast]p4:
// If the value of v is a null pointer value in the pointer case, the result
// is the null pointer value of type T.
bool shouldNullCheckSrcValue =
CGM.getCXXABI().shouldDynamicCastCallBeNullChecked(srcTy->isPointerType(),
srcRecordTy);

auto buildDynamicCastAfterNullCheck = [&]() -> mlir::Value {
if (isDynCastToVoid)
llvm_unreachable("NYI");
else {
assert(destRecordTy->isRecordType() &&
"destination type must be a record type!");
return CGM.getCXXABI().buildDynamicCastCall(
*this, loc, ThisAddr, srcRecordTy, destTy, destRecordTy);
}
};

if (!shouldNullCheckSrcValue)
return buildDynamicCastAfterNullCheck();

mlir::Value srcValueIsNull = builder.createPtrIsNull(ThisAddr.getPointer());
return builder
.create<mlir::cir::TernaryOp>(
loc, srcValueIsNull,
[&](mlir::OpBuilder &, mlir::Location) {
builder.createYield(loc,
buildDynamicCastToNull(*this, loc, destTy));
},
[&](mlir::OpBuilder &, mlir::Location) {
builder.createYield(loc, buildDynamicCastAfterNullCheck());
})
.getResult();
}
7 changes: 5 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
// the alignment.
return CGF.buildPointerWithAlignment(CE).getPointer();
}
case CK_Dynamic:
llvm_unreachable("NYI");
case CK_Dynamic: {
Address V = CGF.buildPointerWithAlignment(E);
const auto *DCE = cast<CXXDynamicCastExpr>(CE);
return CGF.buildDynamicCast(V, DCE);
}
case CK_ArrayToPointerDecay:
return CGF.buildArrayToPointerDecay(E).getPointer();
case CK_FunctionToPointerDecay:
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ class CIRGenFunction : public CIRGenTypeCache {
TCK_MemberCall,
/// Checking the 'this' pointer for a constructor call.
TCK_ConstructorCall,
/// Checking the operand of a dynamic_cast or a typeid expression. Must be
/// null or an object within its lifetime.
TCK_DynamicOperation
};

// Holds coroutine data if the current function is a coroutine. We use a
Expand Down Expand Up @@ -630,6 +633,8 @@ class CIRGenFunction : public CIRGenTypeCache {
QualType DeleteTy, mlir::Value NumElements = nullptr,
CharUnits CookieSize = CharUnits());

mlir::Value buildDynamicCast(Address ThisAddr, const CXXDynamicCastExpr *DCE);

mlir::Value createLoad(const clang::VarDecl *VD, const char *Name);

mlir::Value buildScalarPrePostIncDec(const UnaryOperator *E, LValue LV,
Expand Down Expand Up @@ -794,6 +799,9 @@ class CIRGenFunction : public CIRGenTypeCache {
RValue buildCallExpr(const clang::CallExpr *E,
ReturnValueSlot ReturnValue = ReturnValueSlot());

mlir::Value buildRuntimeCall(mlir::Location loc, mlir::cir::FuncOp callee,
ArrayRef<mlir::Value> args = {});

/// Create a check for a function parameter that may potentially be
/// declared as non-null.
void buildNonNullArgCheck(RValue RV, QualType ArgType, SourceLocation ArgLoc,
Expand Down
144 changes: 144 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,18 @@ class CIRGenItaniumCXXABI : public cir::CIRGenCXXABI {
return Args.size() - 1;
}

void buildBadCastCall(CIRGenFunction &CGF, mlir::Location loc) override;

bool shouldDynamicCastCallBeNullChecked(bool SrcIsPtr,
QualType SrcRecordTy) override {
return SrcIsPtr;
}

mlir::Value buildDynamicCastCall(CIRGenFunction &CGF, mlir::Location Loc,
Address Value, QualType SrcRecordTy,
QualType DestTy,
QualType DestRecordTy) override;

/**************************** RTTI Uniqueness ******************************/
protected:
/// Returns true if the ABI requires RTTI type_info objects to be unique
Expand Down Expand Up @@ -2173,3 +2185,135 @@ void CIRGenItaniumCXXABI::buildThrow(CIRGenFunction &CGF,
builder.create<mlir::cir::ThrowOp>(CGF.getLoc(E->getSourceRange()),
exceptionPtr, typeInfo.getSymbol(), dtor);
}

static mlir::cir::FuncOp getBadCastFn(CIRGenFunction &CGF) {
// Prototype: void __cxa_bad_cast();
mlir::cir::FuncType FTy =
CGF.getBuilder().getFuncType({}, CGF.getBuilder().getVoidTy());
return CGF.CGM.getOrCreateRuntimeFunction(FTy, "__cxa_bad_cast");
}

void CIRGenItaniumCXXABI::buildBadCastCall(CIRGenFunction &CGF,
mlir::Location loc) {
CGF.buildRuntimeCall(loc, getBadCastFn(CGF));
// TODO(cir): mark the current insertion point as unreachable.
}

static CharUnits computeOffsetHint(ASTContext &Context,
const CXXRecordDecl *Src,
const CXXRecordDecl *Dst) {
CXXBasePaths Paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
/*DetectVirtual=*/false);

// If Dst is not derived from Src we can skip the whole computation below and
// return that Src is not a public base of Dst. Record all inheritance paths.
if (!Dst->isDerivedFrom(Src, Paths))
return CharUnits::fromQuantity(-2ULL);

unsigned NumPublicPaths = 0;
CharUnits Offset;

// Now walk all possible inheritance paths.
for (const CXXBasePath &Path : Paths) {
if (Path.Access != AS_public) // Ignore non-public inheritance.
continue;

++NumPublicPaths;

for (const CXXBasePathElement &PathElement : Path) {
// If the path contains a virtual base class we can't give any hint.
// -1: no hint.
if (PathElement.Base->isVirtual())
return CharUnits::fromQuantity(-1ULL);

if (NumPublicPaths > 1) // Won't use offsets, skip computation.
continue;

// Accumulate the base class offsets.
const ASTRecordLayout &L = Context.getASTRecordLayout(PathElement.Class);
Offset += L.getBaseClassOffset(
PathElement.Base->getType()->getAsCXXRecordDecl());
}
}

// -2: Src is not a public base of Dst.
if (NumPublicPaths == 0)
return CharUnits::fromQuantity(-2ULL);

// -3: Src is a multiple public base type but never a virtual base type.
if (NumPublicPaths > 1)
return CharUnits::fromQuantity(-3ULL);

// Otherwise, the Src type is a unique public nonvirtual base type of Dst.
// Return the offset of Src from the origin of Dst.
return Offset;
}

static mlir::cir::FuncOp getItaniumDynamicCastFn(CIRGenFunction &CGF) {
// Prototype:
// void *__dynamic_cast(const void *sub,
// global_as const abi::__class_type_info *src,
// global_as const abi::__class_type_info *dst,
// std::ptrdiff_t src2dst_offset);

mlir::Type VoidPtrTy = CGF.VoidPtrTy;
mlir::Type RTTIPtrTy = CGF.getBuilder().getUInt8PtrTy();
mlir::Type PtrDiffTy = CGF.ConvertType(CGF.getContext().getPointerDiffType());

// TODO(cir): mark the function as nowind readonly.

mlir::cir::FuncType FTy = CGF.getBuilder().getFuncType(
{VoidPtrTy, RTTIPtrTy, RTTIPtrTy, PtrDiffTy}, VoidPtrTy);
return CGF.CGM.getOrCreateRuntimeFunction(FTy, "__dynamic_cast");
}

mlir::Value CIRGenItaniumCXXABI::buildDynamicCastCall(
CIRGenFunction &CGF, mlir::Location Loc, Address Value,
QualType SrcRecordTy, QualType DestTy, QualType DestRecordTy) {
mlir::Type ptrdiffTy = CGF.ConvertType(CGF.getContext().getPointerDiffType());

mlir::Value srcRtti = CGF.getBuilder().getConstant(
Loc,
CGF.CGM.getAddrOfRTTIDescriptor(Loc, SrcRecordTy.getUnqualifiedType())
.cast<mlir::TypedAttr>());
mlir::Value destRtti = CGF.getBuilder().getConstant(
Loc,
CGF.CGM.getAddrOfRTTIDescriptor(Loc, DestRecordTy.getUnqualifiedType())
.cast<mlir::TypedAttr>());

// Compute the offset hint.
const CXXRecordDecl *srcDecl = SrcRecordTy->getAsCXXRecordDecl();
const CXXRecordDecl *destDecl = DestRecordTy->getAsCXXRecordDecl();
mlir::Value offsetHint = CGF.getBuilder().getConstAPInt(
Loc, ptrdiffTy,
llvm::APSInt::get(computeOffsetHint(CGF.getContext(), srcDecl, destDecl)
.getQuantity()));

// Emit the call to __dynamic_cast.
mlir::Value srcPtr =
CGF.getBuilder().createBitcast(Value.getPointer(), CGF.VoidPtrTy);
mlir::Value args[4] = {srcPtr, srcRtti, destRtti, offsetHint};
mlir::Value castedPtr =
CGF.buildRuntimeCall(Loc, getItaniumDynamicCastFn(CGF), args);

assert(castedPtr.getType().isa<mlir::cir::PointerType>() &&
"the return value of __dynamic_cast should be a ptr");

/// C++ [expr.dynamic.cast]p9:
/// A failed cast to reference type throws std::bad_cast
if (DestTy->isReferenceType()) {
// Emit a cir.if that checks the casted value.
mlir::Value castedValueIsNull = CGF.getBuilder().createPtrIsNull(castedPtr);
CGF.getBuilder().create<mlir::cir::IfOp>(
Loc, castedValueIsNull, false, [&](mlir::OpBuilder &, mlir::Location) {
buildBadCastCall(CGF, Loc);
// TODO(cir): remove this once buildBadCastCall inserts unreachable
CGF.getBuilder().createYield(Loc);
});
}

// Note that castedPtr is a void*. Cast it to a pointer to the destination
// type before return.
mlir::Type destCIRTy = CGF.ConvertType(DestTy);
return CGF.getBuilder().createBitcast(castedPtr, destCIRTy);
}
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1977,6 +1977,17 @@ CIRGenModule::createCIRFunction(mlir::Location loc, StringRef name,
return f;
}

mlir::cir::FuncOp
CIRGenModule::getOrCreateRuntimeFunction(mlir::cir::FuncType Ty,
StringRef Name) {
auto entry = cast_if_present<mlir::cir::FuncOp>(getGlobalValue(Name));
if (entry)
return entry;

return createCIRFunction(mlir::UnknownLoc::get(builder.getContext()), Name,
Ty, nullptr);
}

bool isDefaultedMethod(const clang::FunctionDecl *FD) {
if (FD->isDefaulted() && isa<CXXMethodDecl>(FD) &&
(cast<CXXMethodDecl>(FD)->isCopyAssignmentOperator() ||
Expand Down
Loading

0 comments on commit ba4b5aa

Please sign in to comment.