From fde55c947eef379b97b1bdd700242d678fe8f767 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 2 Nov 2023 17:58:03 +0200 Subject: [PATCH] Introduce type cloning. --- include/clad/Differentiator/StmtClone.h | 8 ++ include/clad/Differentiator/VisitorBase.h | 3 + lib/Differentiator/ReverseModeVisitor.cpp | 4 +- lib/Differentiator/StmtClone.cpp | 109 ++++++++++++++-------- lib/Differentiator/VisitorBase.cpp | 7 ++ 5 files changed, 88 insertions(+), 43 deletions(-) diff --git a/include/clad/Differentiator/StmtClone.h b/include/clad/Differentiator/StmtClone.h index 61e79e6b6..ee5d2e801 100644 --- a/include/clad/Differentiator/StmtClone.h +++ b/include/clad/Differentiator/StmtClone.h @@ -48,6 +48,10 @@ namespace utils { template StmtTy* Clone(const StmtTy* S); + /// Cloning types is necessary since VariableArrayType + /// store a pointer to their size expression. + clang::QualType CloneType(const clang::QualType T); + // visitor part (not for public use) // Stmt.def could be used if ABSTR_STMT is introduced #define DECLARE_CLONE_FN(CLASS) clang::Stmt* Visit ## CLASS(clang::CLASS *Node); @@ -153,6 +157,10 @@ namespace utils { ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S, const clang::FunctionDecl* FD); bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); + bool VisitStmt(clang::Stmt* S); + /// Used to update the size expression of QT + /// if QT is VariableArrayType. + void updateType(clang::QualType QT); }; } // namespace utils } // namespace clad diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 81e4af9a2..9036c2c96 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -558,6 +558,9 @@ namespace clad { clang::Stmt* Clone(const clang::Stmt* S); /// A shorthand to simplify cloning of expressions. clang::Expr* Clone(const clang::Expr* E); + /// Cloning types is necessary since VariableArrayType + /// store a pointer to their size expression. + clang::QualType CloneType(const clang::QualType T); }; } // end namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e5d7020e..cd39fb06f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2534,11 +2534,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. if (VD->getType()->isRecordType()) - VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), + VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit(), VD->getTypeSourceInfo(), VD->getInitStyle()); else - VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), + VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit()); Expr* derivedVDE = BuildDeclRef(VDDerived); diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 241fe79ff..4dbbdd15d 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -63,19 +63,19 @@ Stmt* StmtClone::Visit ## CLASS(CLASS *Node) \ return result; \ } -DEFINE_CLONE_EXPR_CO11(BinaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc(), Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams))) -DEFINE_CLONE_EXPR_CO11(UnaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getSubExpr()), Node->getOpcode(), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc() CLAD_COMPAT_CLANG7_UnaryOperator_ExtraParams CLAD_COMPAT_CLANG11_UnaryOperator_ExtraParams)) +DEFINE_CLONE_EXPR_CO11(BinaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc(), Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams))) +DEFINE_CLONE_EXPR_CO11(UnaryOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getSubExpr()), Node->getOpcode(), CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind(), Node->getOperatorLoc() CLAD_COMPAT_CLANG7_UnaryOperator_ExtraParams CLAD_COMPAT_CLANG11_UnaryOperator_ExtraParams)) Stmt* StmtClone::VisitDeclRefExpr(DeclRefExpr *Node) { TemplateArgumentListInfo TAListInfo; Node->copyTemplateArgumentsInto(TAListInfo); - return DeclRefExpr::Create(Ctx, Node->getQualifierLoc(), Node->getTemplateKeywordLoc(), Node->getDecl(), Node->refersToEnclosingVariableOrCapture(), Node->getNameInfo(), Node->getType(), Node->getValueKind(), Node->getFoundDecl(), &TAListInfo); + return DeclRefExpr::Create(Ctx, Node->getQualifierLoc(), Node->getTemplateKeywordLoc(), Node->getDecl(), Node->refersToEnclosingVariableOrCapture(), Node->getNameInfo(), CloneType(Node->getType()), Node->getValueKind(), Node->getFoundDecl(), &TAListInfo); } -DEFINE_CREATE_EXPR(IntegerLiteral, (Ctx, Node->getValue(), Node->getType(), Node->getLocation())) -DEFINE_CLONE_EXPR_CO(PredefinedExpr, (CLAD_COMPAT_CLANG8_Ctx_ExtraParams Node->getLocation(), Node->getType(), Node->getIdentKind() CLAD_COMPAT_CLANG17_IsTransparent(Node), Node->getFunctionName())) -DEFINE_CLONE_EXPR(CharacterLiteral, (Node->getValue(), Node->getKind(), Node->getType(), Node->getLocation())) -DEFINE_CLONE_EXPR(ImaginaryLiteral, (Clone(Node->getSubExpr()), Node->getType())) +DEFINE_CREATE_EXPR(IntegerLiteral, (Ctx, Node->getValue(), CloneType(Node->getType()), Node->getLocation())) +DEFINE_CLONE_EXPR_CO(PredefinedExpr, (CLAD_COMPAT_CLANG8_Ctx_ExtraParams Node->getLocation(), CloneType(Node->getType()), Node->getIdentKind() CLAD_COMPAT_CLANG17_IsTransparent(Node), Node->getFunctionName())) +DEFINE_CLONE_EXPR(CharacterLiteral, (Node->getValue(), Node->getKind(), CloneType(Node->getType()), Node->getLocation())) +DEFINE_CLONE_EXPR(ImaginaryLiteral, (Clone(Node->getSubExpr()), CloneType(Node->getType()))) DEFINE_CLONE_EXPR(ParenExpr, (Node->getLParen(), Node->getRParen(), Clone(Node->getSubExpr()))) -DEFINE_CLONE_EXPR(ArraySubscriptExpr, (Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getRBracketLoc())) +DEFINE_CLONE_EXPR(ArraySubscriptExpr, (Clone(Node->getLHS()), Clone(Node->getRHS()), CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind(), Node->getRBracketLoc())) DEFINE_CREATE_EXPR(CXXDefaultArgExpr, (Ctx, SourceLocation(), Node->getParam() CLAD_COMPAT_CLANG16_CXXDefaultArgExpr_getRewrittenExpr_Param(Node) CLAD_COMPAT_CLANG9_CXXDefaultArgExpr_getUsedContext_Param(Node))) Stmt* StmtClone::VisitMemberExpr(MemberExpr* Node) { @@ -92,24 +92,24 @@ Stmt* StmtClone::VisitMemberExpr(MemberExpr* Node) { Node->getFoundDecl(), Node->getMemberNameInfo(), &TemplateArgs, - Node->getType(), + CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind() - CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( - Node->isNonOdrUse())); + CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + Node->isNonOdrUse())); // Copy Value and Type dependent clad_compat::ExprSetDeps(result, Node); return result; } -DEFINE_CLONE_EXPR(CompoundLiteralExpr, (Node->getLParenLoc(), Node->getTypeSourceInfo(), Node->getType(), Node->getValueKind(), Clone(Node->getInitializer()), Node->isFileScope())) -DEFINE_CREATE_EXPR(ImplicitCastExpr, (Ctx, Node->getType(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getValueKind() /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node) )) -DEFINE_CREATE_EXPR(CStyleCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getTypeInfoAsWritten(), Node->getLParenLoc(), Node->getRParenLoc())) -DEFINE_CREATE_EXPR(CXXStaticCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten() /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXDynamicCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXConstCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Clone(Node->getSubExpr()), Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXConstructExpr, (Ctx, Node->getType(), Node->getLocation(), Node->getConstructor(), Node->isElidable(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), Node->getConstructionKind(), Node->getParenOrBraceRange())) -DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getTypeInfoAsWritten(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getLParenLoc(), Node->getRParenLoc())) +DEFINE_CLONE_EXPR(CompoundLiteralExpr, (Node->getLParenLoc(), Node->getTypeSourceInfo(), CloneType(Node->getType()), Node->getValueKind(), Clone(Node->getInitializer()), Node->isFileScope())) +DEFINE_CREATE_EXPR(ImplicitCastExpr, (Ctx, CloneType(Node->getType()), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getValueKind() /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node) )) +DEFINE_CREATE_EXPR(CStyleCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getTypeInfoAsWritten(), Node->getLParenLoc(), Node->getRParenLoc())) +DEFINE_CREATE_EXPR(CXXStaticCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten() /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXDynamicCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXConstCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Clone(Node->getSubExpr()), Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) +DEFINE_CREATE_EXPR(CXXConstructExpr, (Ctx, CloneType(Node->getType()), Node->getLocation(), Node->getConstructor(), Node->isElidable(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), Node->getConstructionKind(), Node->getParenOrBraceRange())) +DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, (Ctx, CloneType(Node->getType()), Node->getValueKind(), Node->getTypeInfoAsWritten(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getLParenLoc(), Node->getRParenLoc())) DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), Node->cleanupsHaveSideEffects(), {})) // clang <= 7 do not have `ConstantExpr` node. @@ -117,27 +117,27 @@ DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), DEFINE_CREATE_EXPR(ConstantExpr, (Ctx, Clone(Node->getSubExpr()) CLAD_COMPAT_ConstantExpr_Create_ExtraParams)) #endif -DEFINE_CLONE_EXPR_CO(CXXTemporaryObjectExpr, (Ctx, Node->getConstructor(), Node->getType(), Node->getTypeSourceInfo(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->getSourceRange(), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization())) +DEFINE_CLONE_EXPR_CO(CXXTemporaryObjectExpr, (Ctx, Node->getConstructor(), CloneType(Node->getType()), Node->getTypeSourceInfo(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->getSourceRange(), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization())) -DEFINE_CLONE_EXPR(MaterializeTemporaryExpr, (Node->getType(), CLAD_COMPAT_CLANG10_GetTemporaryExpr(Node), Node->isBoundToLvalueReference())) -DEFINE_CLONE_EXPR_CO11(CompoundAssignOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), Node->getType(), +DEFINE_CLONE_EXPR(MaterializeTemporaryExpr, (CloneType(Node->getType()), CLAD_COMPAT_CLANG10_GetTemporaryExpr(Node), Node->isBoundToLvalueReference())) +DEFINE_CLONE_EXPR_CO11(CompoundAssignOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind(), CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed Node->getOperatorLoc(), Node->getFPFeatures(CLAD_COMPAT_CLANG11_LangOptions_EtraParams) CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved)) -DEFINE_CLONE_EXPR(ConditionalOperator, (Clone(Node->getCond()), Node->getQuestionLoc(), Clone(Node->getLHS()), Node->getColonLoc(), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind())) -DEFINE_CLONE_EXPR(AddrLabelExpr, (Node->getAmpAmpLoc(), Node->getLabelLoc(), Node->getLabel(), Node->getType())) -DEFINE_CLONE_EXPR(StmtExpr, (Clone(Node->getSubStmt()), Node->getType(), Node->getLParenLoc(), Node->getRParenLoc() CLAD_COMPAT_CLANG10_StmtExpr_Create_ExtraParams )) -DEFINE_CLONE_EXPR(ChooseExpr, (Node->getBuiltinLoc(), Clone(Node->getCond()), Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getRParenLoc(), Node->isConditionTrue() CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed)) -DEFINE_CLONE_EXPR(GNUNullExpr, (Node->getType(), Node->getTokenLocation())) -DEFINE_CLONE_EXPR(VAArgExpr, (Node->getBuiltinLoc(), Clone(Node->getSubExpr()), Node->getWrittenTypeInfo(), Node->getRParenLoc(), Node->getType(), Node->isMicrosoftABI())) -DEFINE_CLONE_EXPR(ImplicitValueInitExpr, (Node->getType())) +DEFINE_CLONE_EXPR(ConditionalOperator, (Clone(Node->getCond()), Node->getQuestionLoc(), Clone(Node->getLHS()), Node->getColonLoc(), Clone(Node->getRHS()), CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind())) +DEFINE_CLONE_EXPR(AddrLabelExpr, (Node->getAmpAmpLoc(), Node->getLabelLoc(), Node->getLabel(), CloneType(Node->getType()))) +DEFINE_CLONE_EXPR(StmtExpr, (Clone(Node->getSubStmt()), CloneType(Node->getType()), Node->getLParenLoc(), Node->getRParenLoc() CLAD_COMPAT_CLANG10_StmtExpr_Create_ExtraParams )) +DEFINE_CLONE_EXPR(ChooseExpr, (Node->getBuiltinLoc(), Clone(Node->getCond()), Clone(Node->getLHS()), Clone(Node->getRHS()), CloneType(Node->getType()), Node->getValueKind(), Node->getObjectKind(), Node->getRParenLoc(), Node->isConditionTrue() CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed)) +DEFINE_CLONE_EXPR(GNUNullExpr, (CloneType(Node->getType()), Node->getTokenLocation())) +DEFINE_CLONE_EXPR(VAArgExpr, (Node->getBuiltinLoc(), Clone(Node->getSubExpr()), Node->getWrittenTypeInfo(), Node->getRParenLoc(), CloneType(Node->getType()), Node->isMicrosoftABI())) +DEFINE_CLONE_EXPR(ImplicitValueInitExpr, (CloneType(Node->getType()))) DEFINE_CLONE_EXPR(ExtVectorElementExpr, (Node->getType(), Node->getValueKind(), Clone(Node->getBase()), Node->getAccessor(), Node->getAccessorLoc())) DEFINE_CLONE_EXPR(CXXBoolLiteralExpr, (Node->getValue(), Node->getType(), Node->getSourceRange().getBegin())) DEFINE_CLONE_EXPR(CXXNullPtrLiteralExpr, (Node->getType(), Node->getSourceRange().getBegin())) DEFINE_CLONE_EXPR(CXXThisExpr, (Node->getSourceRange().getBegin(), Node->getType(), Node->isImplicit())) DEFINE_CLONE_EXPR(CXXThrowExpr, (Clone(Node->getSubExpr()), Node->getType(), Node->getThrowLoc(), Node->isThrownVariableInScope())) #if CLANG_VERSION_MAJOR < 16 -DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getParameter(), CLAD_COMPAT_SubstNonTypeTemplateParmExpr_isReferenceParameter_ExtraParam(Node) Node->getReplacement())) +DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (CloneType(Node->getType()), Node->getValueKind(), Node->getBeginLoc(), Node->getParameter(), CLAD_COMPAT_SubstNonTypeTemplateParmExpr_isReferenceParameter_ExtraParam(Node) Node->getReplacement())) #else -DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getReplacement(), Node->getAssociatedDecl(), Node->getIndex(), Node->getPackIndex(), Node->isReferenceParameter())); +DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (CloneType(Node->getType()), Node->getValueKind(), Node->getBeginLoc(), Node->getReplacement(), Node->getAssociatedDecl(), Node->getIndex(), Node->getPackIndex(), Node->isReferenceParameter())); #endif DEFINE_CREATE_EXPR(PseudoObjectExpr, (Ctx, Node->getSyntacticForm(), llvm::SmallVector(Node->semantics_begin(), Node->semantics_end()), Node->getResultExprIndex())) //BlockExpr @@ -147,14 +147,14 @@ Stmt* StmtClone::VisitStringLiteral(StringLiteral* Node) { llvm::SmallVector concatLocations(Node->tokloc_begin(), Node->tokloc_end()); return StringLiteral::Create(Ctx, Node->getString(), Node->getKind(), - Node->isPascal(), Node->getType(), + Node->isPascal(), CloneType(Node->getType()), &concatLocations[0], concatLocations.size()); } Stmt* StmtClone::VisitFloatingLiteral(FloatingLiteral* Node) { FloatingLiteral* clone = FloatingLiteral::Create(Ctx, Node->getValue(), Node->isExact(), - Node->getType(), + CloneType(Node->getType()), Node->getLocation()); clone->setSemantics(Node->getSemantics()); return clone; @@ -195,12 +195,12 @@ Stmt* StmtClone::VisitUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr* Node) { if (Node->isArgumentType()) return new (Ctx) UnaryExprOrTypeTraitExpr(Node->getKind(), Node->getArgumentTypeInfo(), - Node->getType(), + CloneType(Node->getType()), Node->getOperatorLoc(), Node->getRParenLoc()); return new (Ctx) UnaryExprOrTypeTraitExpr(Node->getKind(), Clone(Node->getArgumentExpr()), - Node->getType(), + CloneType(Node->getType()), Node->getOperatorLoc(), Node->getRParenLoc()); } @@ -208,7 +208,7 @@ Stmt* StmtClone::VisitUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr* Node) { Stmt* StmtClone::VisitCallExpr(CallExpr* Node) { CallExpr* result = clad_compat::CallExpr_Create(Ctx, Clone(Node->getCallee()), llvm::ArrayRef(), - Node->getType(), + CloneType(Node->getType()), Node->getValueKind(), Node->getRParenLoc() CLAD_COMPAT_CLANG8_CallExpr_ExtraParams); @@ -248,7 +248,7 @@ Stmt* StmtClone::VisitCXXOperatorCallExpr(CXXOperatorCallExpr* Node) { } CXXOperatorCallExpr* result = clad_compat::CXXOperatorCallExpr_Create( Ctx, Node->getOperator(), Clone(Node->getCallee()), clonedArgs, - Node->getType(), Node->getValueKind(), Node->getRParenLoc(), + CloneType(Node->getType()), Node->getValueKind(), Node->getRParenLoc(), Node->getFPFeatures() CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParams); //### result->setNumArgs(Ctx, Node->getNumArgs()); @@ -265,7 +265,7 @@ Stmt* StmtClone::VisitCXXOperatorCallExpr(CXXOperatorCallExpr* Node) { Stmt* StmtClone::VisitCXXMemberCallExpr(CXXMemberCallExpr * Node) { CXXMemberCallExpr* result = clad_compat::CXXMemberCallExpr_Create(Ctx, Clone(Node->getCallee()), 0, - Node->getType(), + CloneType(Node->getType()), Node->getValueKind(), Node->getRParenLoc() /*FP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node) @@ -288,7 +288,7 @@ Stmt* StmtClone::VisitShuffleVectorExpr(ShuffleVectorExpr* Node) { cloned[i] = Clone(Node->getExpr(i)); llvm::ArrayRef clonedRef = clad_compat::makeArrayRef(cloned.data(), cloned.size()); - return new (Ctx) ShuffleVectorExpr(Ctx, clonedRef, Node->getType(), + return new (Ctx) ShuffleVectorExpr(Ctx, clonedRef, CloneType(Node->getType()), Node->getBuiltinLoc(), Node->getRParenLoc()); } @@ -366,7 +366,7 @@ Decl* StmtClone::CloneDecl(Decl* Node) { VarDecl* cloned_Decl = VarDecl::Create(Ctx, VD->getDeclContext(), VD->getLocation(), VD->getInnerLocStart(), - VD->getIdentifier(), VD->getType(), + VD->getIdentifier(), CloneType(VD->getType()), VD->getTypeSourceInfo(), VD->getStorageClass()); if (VD->getInit()) @@ -435,9 +435,36 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) { VD->setReferenced(); VD->setIsUsed(); } + updateType(DRE->getType()); return true; } +bool ReferencesUpdater::VisitStmt(clang::Stmt* S) { + if (auto* E = dyn_cast(S)) { + updateType(E->getType()); + } + return true; +} + +void ReferencesUpdater::updateType(QualType QT) { + if (auto* varArrType = dyn_cast(QT)) { + TraverseStmt(varArrType->getSizeExpr()); + } +} + +QualType StmtClone::CloneType(const clang::QualType T) { + if (const auto* varArrType = + dyn_cast(T.getTypePtr())) { + auto elemType = varArrType->getElementType(); + return Ctx.getVariableArrayType(elemType, Clone(varArrType->getSizeExpr()), + varArrType->getSizeModifier(), + T.getQualifiers().getAsOpaqueValue(), + SourceRange()); + } + + return clang::QualType(T.getTypePtr(), T.getQualifiers().getAsOpaqueValue()); +} + //--------------------------------------------------------- } // end namespace utils } // end namespace clad diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 90f42dc21..8ee1fd798 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -301,6 +301,13 @@ namespace clad { return llvm::cast(Clone(S)); } + + QualType VisitorBase::CloneType(const QualType QT) { + auto clonedType = m_Builder.m_NodeCloner->CloneType(QT); + utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function); + up.updateType(clonedType); + return clonedType; + } Expr* VisitorBase::BuildOp(UnaryOperatorKind OpCode, Expr* E, SourceLocation OpLoc) { if (!E)