From 584eb31b7b2f4d2e47347a753710ad0b8093c8d2 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Wed, 25 Dec 2024 09:37:14 +0000 Subject: [PATCH] Improve derivative registration. Once the derivative's FunctionDecl is created it needs to be registered properly in the AST. Before this patch `registerDerivative` ignored many checks which make sure the produced code is valid C++. This patch enforces more rigorous checks for some declaration kinds such as standalone functions and some cases of class functions. Lambdas will be handled in subsequent patches. The challenge is to produce valid C++ code when differentiating classes. Consider: ```cpp class A1 { double f(double x) { return x * x; } double f_grad(double x, double *d_x); // forward declaration of the gradient. }; ``` In such cases clad will now produce a correct out-of-line definition of `f_grad`: `double A1::f_grad(double x, double *d_x) { ... }`. In cases where the gradient function is not declared as part of the class signature (external libraries such as STL), we continue to handle them as before. That will be dealt with in a subsequent patches. This also fixes assertions in debug mode when there are virtual functions. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 28 ++++-- lib/Differentiator/DerivativeBuilder.cpp | 98 +++++++++---------- .../ReverseModeForwPassVisitor.cpp | 7 -- lib/Differentiator/ReverseModeVisitor.cpp | 30 ++++-- lib/Differentiator/VisitorBase.cpp | 5 +- test/FirstDerivative/ClassMethodCall.C | 6 +- test/FirstDerivative/VirtualMethodsCall.C | 21 ++-- test/Functors/Simple.C | 2 +- test/Gradient/Gradients.C | 2 +- test/Gradient/MemberFunctions.C | 48 ++++----- test/Misc/RunDemos.C | 2 - tools/ClangPlugin.cpp | 2 +- 12 files changed, 136 insertions(+), 115 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index e54043cc6..3f8205ecb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -201,16 +201,30 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { addToCurrentBlock(BodyDiff); Stmt* derivativeBody = endBlock(); derivedFD->setBody(derivativeBody); - endScope(); // Function body scope + } - // Size >= current derivative order means that there exists a declaration - // or prototype for the currently derived function. - if (m_DiffReq.DerivedFDPrototypes.size() >= - m_DiffReq.CurrentDerivativeOrder) - m_Derivative->setPreviousDeclaration( - m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); + // FIXME: Drop the static specifier for the out-of-line definitions. + if (auto* RD = dyn_cast(m_Derivative->getDeclContext())) { + DeclContext::lookup_result R = + RD->getPrimaryContext()->lookup(m_Derivative->getDeclName()); + FunctionDecl* FoundFD = + R.empty() ? nullptr : dyn_cast(R.front()); + if (!RD->isLambda() && !R.empty() && + !m_Builder.m_DFC.IsCladDerivative(FoundFD)) { + Sema::NestedNameSpecInfo IdInfo(RD->getIdentifier(), noLoc, noLoc, + /*ObjectType=*/nullptr); + // FIXME: Address nested classes where SS should be set. + CXXScopeSpec SS; + m_Sema.BuildCXXNestedNameSpecifier(getCurrentScope(), IdInfo, + /*EnteringContext=*/true, SS, + /*ScopeLookupResult=*/nullptr, + /*ErrorRecoveryLookup=*/false); + m_Derivative->setQualifierInfo(SS.getWithLocInContext(m_Context)); + m_Derivative->setLexicalDeclContext(RD->getParent()); + } } + m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); endScope(); // Function decl scope diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 8f67ca588..19c62db5f 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -8,8 +8,12 @@ #include "JacobianModeVisitor.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/Attr.h" +#include "clang/AST/Decl.h" #include "clang/AST/ExprCXX.h" #include "clang/AST/TemplateBase.h" +#include "clang/Basic/LLVM.h" // isa, dyn_cast +#include "clang/Basic/Specifiers.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" @@ -50,48 +54,33 @@ DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, DerivativeBuilder::~DerivativeBuilder() {} -static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { - LookupResult R(semaRef, derivedFD->getNameInfo(), Sema::LookupOrdinaryName); - // FIXME: Attach out-of-line virtual function definitions to the TUScope. - Scope* S = semaRef.getScopeForContext(derivedFD->getDeclContext()); - semaRef.CheckFunctionDeclaration( - S, derivedFD, R, - /*IsMemberSpecialization=*/ - false - /*DeclIsDefn*/ CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam( - derivedFD)); - - // FIXME: Avoid the DeclContext lookup and the manual setPreviousDecl. - // Consider out-of-line virtual functions. - { - DeclContext* LookupCtx = derivedFD->getDeclContext(); - // Find the first non-transparent context to perform the lookup in. - while (LookupCtx->isTransparentContext()) - LookupCtx = LookupCtx->getParent(); - auto R = LookupCtx->noload_lookup(derivedFD->getDeclName()); - - for (NamedDecl* I : R) { - if (auto* FD = dyn_cast(I)) { - // FIXME: We still do extra work in creating a derivative and throwing - // it away. - if (FD->getDefinition()) - return; - - if (derivedFD->getASTContext().hasSameFunctionTypeIgnoringExceptionSpec( - derivedFD->getType(), FD->getType())) { - // Register the function on the redecl chain. - derivedFD->setPreviousDecl(FD); - break; - } - } - } - // Inform the decl's decl context for its existance after the lookup, - // otherwise it would end up in the LookupResult. - derivedFD->getDeclContext()->addDecl(derivedFD); - - // FIXME: Rebuild VTable to remove requirements for "forward" declared - // virtual methods +static void registerDerivative(FunctionDecl* dFD, Sema& S, + const DiffRequest& R) { + DeclContext* DC = dFD->getLexicalDeclContext(); + LookupResult Previous(S, dFD->getNameInfo(), Sema::LookupOrdinaryName); + S.LookupQualifiedName(Previous, dFD->getParent()); + + // Check if we created a top-level decl with the same name for another class. + // FIXME: This case should be addressed by providing proper names and function + // implementation that does not rely on accessing private data from the class. + bool IsBrokenDecl = isa(DC); + if (!IsBrokenDecl) { + S.CheckFunctionDeclaration( + /*Scope=*/nullptr, dFD, Previous, + /*IsMemberSpecialization=*/ + false + /*DeclIsDefn*/ + CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam(dFD)); + } else if (R.DerivedFDPrototypes.size() >= R.CurrentDerivativeOrder) { + // Size >= current derivative order means that there exists a declaration + // or prototype for the currently derived function. + dFD->setPreviousDecl(R.DerivedFDPrototypes[R.CurrentDerivativeOrder - 1]); } + + if (dFD->isInvalidDecl()) + return; // CheckFunctionDeclaration was unhappy about derivedFD + + DC->addDecl(dFD); } static bool hasAttribute(const Decl *D, attr::Kind Kind) { @@ -107,33 +96,42 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { clang::DeclarationNameInfo name, clang::QualType functionType) { FunctionDecl* returnedFD = nullptr; NamespaceDecl* enclosingNS = nullptr; + TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(functionType); if (isa(FD)) { CXXRecordDecl* CXXRD = cast(DC); returnedFD = CXXMethodDecl::Create( - m_Context, CXXRD, noLoc, name, functionType, FD->getTypeSourceInfo(), + m_Context, CXXRD, noLoc, name, functionType, TSI, FD->getCanonicalDecl()->getStorageClass() CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(FD), FD->isInlineSpecified(), clad_compat::Function_GetConstexprKind(FD), noLoc); - returnedFD->setAccess(FD->getAccess()); + // Generated member function should be called outside of class definitions + // even if their original function had different access specifier. + returnedFD->setAccess(AS_public); } else { assert (isa(FD) && "Unexpected!"); enclosingNS = VB.RebuildEnclosingNamespaces(DC); returnedFD = FunctionDecl::Create( - m_Context, m_Sema.CurContext, noLoc, name, functionType, - FD->getTypeSourceInfo(), + m_Context, m_Sema.CurContext, noLoc, name, functionType, TSI, FD->getCanonicalDecl()->getStorageClass() CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(FD), FD->isInlineSpecified(), FD->hasWrittenPrototype(), clad_compat::Function_GetConstexprKind(FD) CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams( FD->getTrailingRequiresClause())); - } - for (const FunctionDecl* NFD : FD->redecls()) - for (const auto* Attr : NFD->attrs()) + returnedFD->setAccess(FD->getAccess()); + } + + for (const FunctionDecl* NFD : FD->redecls()) { + for (const auto* Attr : NFD->attrs()) { + // We only need the keywords final and override in the tag declaration. + if (isa(Attr) || isa(Attr)) + continue; if (!hasAttribute(returnedFD, Attr->getKind())) returnedFD->addAttr(Attr->clone(m_Context)); + } + } return { returnedFD, enclosingNS }; } @@ -475,9 +473,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { // derivative is a member function it goes into an infinite loop if (!m_DFC.IsCustomDerivative(result.derivative)) { if (auto* FD = result.derivative) - registerDerivative(FD, m_Sema); + registerDerivative(FD, m_Sema, request); if (auto* OFD = result.overload) - registerDerivative(OFD, m_Sema); + registerDerivative(OFD, m_Sema, request); } return result; diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 8c1a4df68..a6076f1fd 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -82,13 +82,6 @@ DerivativeAndOverload ReverseModeForwPassVisitor::Derive() { Stmt* fnBody = endBlock(); m_Derivative->setBody(fnBody); endScope(); - - // Size >= current derivative order means that there exists a declaration - // or prototype for the currently derived function. - if (m_DiffReq.DerivedFDPrototypes.size() >= - m_DiffReq.CurrentDerivativeOrder) - m_Derivative->setPreviousDeclaration( - m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index cdcc64b74..3d4b5245a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -298,20 +298,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Stmt* fnBody = endBlock(); m_Derivative->setBody(fnBody); - endScope(); // Function body scope - // Size >= current derivative order means that there exists a declaration - // or prototype for the currently derived function. - if (m_DiffReq.DerivedFDPrototypes.size() >= - m_DiffReq.CurrentDerivativeOrder) - m_Derivative->setPreviousDeclaration( - m_DiffReq - .DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); + endScope(); // Function body scope } + m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); endScope(); // Function decl scope + if (auto* RD = dyn_cast(m_Derivative->getDeclContext())) { + DeclContext::lookup_result R = + RD->getPrimaryContext()->lookup(m_Derivative->getDeclName()); + FunctionDecl* FoundFD = + R.empty() ? nullptr : dyn_cast(R.front()); + if (!RD->isLambda() && !R.empty() && + !m_Builder.m_DFC.IsCladDerivative(FoundFD)) { + Sema::NestedNameSpecInfo IdInfo(RD->getIdentifier(), noLoc, noLoc, + /*ObjectType=*/nullptr); + // FIXME: Address nested classes where SS should be set. + CXXScopeSpec SS; + m_Sema.BuildCXXNestedNameSpecifier(getCurrentScope(), IdInfo, + /*EnteringContext=*/true, SS, + /*ScopeLookupResult=*/nullptr, + /*ErrorRecoveryLookup=*/false); + m_Derivative->setQualifierInfo(SS.getWithLocInContext(m_Context)); + m_Derivative->setLexicalDeclContext(RD->getParent()); + } + } + if (!shouldCreateOverload) return DerivativeAndOverload{result.first, /*overload=*/nullptr}; diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 408dd9f91..2f00a2847 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -666,8 +666,9 @@ namespace clad { .get(); isArrow = false; } - NestedNameSpecifierLoc NNS(FD->getQualifier(), - /*Data=*/nullptr); + // Leads to printing this->Class::Function(x, y). + // FIXME: Enable for static functions. + NestedNameSpecifierLoc NNS /* = FD->getQualifierLoc()*/; auto DAP = DeclAccessPair::make(FD, FD->getAccess()); auto* memberExpr = MemberExpr::Create( m_Context, thisExpr, isArrow, Loc, NNS, noLoc, FD, DAP, diff --git a/test/FirstDerivative/ClassMethodCall.C b/test/FirstDerivative/ClassMethodCall.C index 14c83c5e2..02208ae88 100644 --- a/test/FirstDerivative/ClassMethodCall.C +++ b/test/FirstDerivative/ClassMethodCall.C @@ -72,7 +72,7 @@ public: } float vm_darg0(float x, float y); - //CHECK: float vm_darg0(float x, float y) { + //CHECK: float A::vm_darg0(float x, float y) { //CHECK-NEXT: float _d_x = 1; //CHECK-NEXT: float _d_y = 0; //CHECK-NEXT: A _d_this_obj; @@ -91,7 +91,7 @@ public: } float vm_darg0(float x, float y); - //CHECK: float vm_darg0(float x, float y) override { + //CHECK: float B::vm_darg0(float x, float y) { //CHECK-NEXT: float _d_x = 1; //CHECK-NEXT: float _d_y = 0; //CHECK-NEXT: B _d_this_obj; @@ -117,7 +117,7 @@ int main () { auto vm_darg0_B = clad::differentiate(&B::vm, 0); printf("Result is = %f\n", vm_darg0_B.execute(b, 2, 3)); // CHECK-EXEC: Result is = 4.0000 printf("%s\n", vm_darg0_B.getCode()); - //CHECK-EXEC: float vm_darg0(float x, float y) override { + //CHECK-EXEC: float B::vm_darg0(float x, float y) { //CHECK-EXEC-NEXT: float _d_x = 1; //CHECK-EXEC-NEXT: float _d_y = 0; //CHECK-EXEC-NEXT: B _d_this_obj; diff --git a/test/FirstDerivative/VirtualMethodsCall.C b/test/FirstDerivative/VirtualMethodsCall.C index ea7c45199..991a2aba0 100644 --- a/test/FirstDerivative/VirtualMethodsCall.C +++ b/test/FirstDerivative/VirtualMethodsCall.C @@ -1,6 +1,5 @@ // RUN: %cladclang %s -I%S/../../include -oVirtualMethodsCall.out 2>&1 | %filecheck %s // RUN: ./VirtualMethodsCall.out | %filecheck_exec %s -// XFAIL: asserts #include "clad/Differentiator/Differentiator.h" @@ -41,7 +40,7 @@ public: // TODO: Remove call forward when execute works with polymorphic methods auto vm_darg0_cf = clad::differentiate(&A::vm, 0); // FIXME: We need to make this out-of-line -//CHECK: float vm_darg0(float x, float y) { +//CHECK: float A::vm_darg0(float x, float y) { //CHECK-NEXT: float _d_x = 1; //CHECK-NEXT: float _d_y = 0; //CHECK-NEXT: A _d_this_obj; @@ -49,7 +48,7 @@ public: //CHECK-NEXT: return _d_x + _d_y; //CHECK-NEXT: } auto vm_darg1_cf = clad::differentiate(&A::vm, 1); -//CHECK: float vm_darg1(float x, float y) { +//CHECK: float A::vm_darg1(float x, float y) { //CHECK-NEXT: float _d_x = 0; //CHECK-NEXT: float _d_y = 1; //CHECK-NEXT: A _d_this_obj; @@ -130,6 +129,9 @@ public: return x*y + x*y; } + float vm_darg0(float x, float y) override; // forward + float vm_darg1(float x, float y) override; // forward + // Inherited from A // float vm1(float x, float y)... @@ -180,7 +182,8 @@ int main () { printf("---\n"); // CHECK-EXEC: --- auto vm_darg0_A = clad::differentiate((float(A::*)(float,float))&A::vm, 0); -//CHECK: float vm_darg0(float x, float y) override { + //FIXME: This should select the implementation of vm in A! +//CHECK: float B::vm_darg0(float x, float y) { //CHECK-NEXT: float _d_x = 1; //CHECK-NEXT: float _d_y = 0; // CHECK-NEXT: B _d_this_obj; @@ -188,7 +191,7 @@ int main () { //CHECK-NEXT: return _d_x * x + x * _d_x + _d_y * y + y * _d_y; //CHECK-NEXT: } auto vm_darg0_B = clad::differentiate((float(B::*)(float,float))&B::vm, 0); -//CHECK: float vm_darg1(float x, float y) override { +//CHECK: float B::vm_darg1(float x, float y) { //CHECK-NEXT: float _d_x = 0; //CHECK-NEXT: float _d_y = 1; // CHECK-NEXT: B _d_this_obj; @@ -208,7 +211,7 @@ int main () { //CHECK-NEXT: return _d_x - _d_y; //CHECK-NEXT: } auto vm1_darg0_B = clad::differentiate((float(B::*)(float,float))&B::vm1, 0); -//CHECK: float vm1_darg0(float x, float y) override { +//CHECK: float vm1_darg0(float x, float y) { //CHECK-NEXT: float _d_x = 1; //CHECK-NEXT: float _d_y = 0; // CHECK-NEXT: B _d_this_obj; @@ -224,7 +227,7 @@ int main () { //CHECK-NEXT: return _d_x - _d_y; //CHECK-NEXT: } auto vm1_darg1_B = clad::differentiate((float(B::*)(float,float))&B::vm1, 1); -//CHECK: float vm1_darg1(float x, float y) override { +//CHECK: float vm1_darg1(float x, float y) { //CHECK-NEXT: float _d_x = 0; //CHECK-NEXT: float _d_y = 1; // CHECK-NEXT: B _d_this_obj; @@ -255,7 +258,7 @@ int main () { printf("---\n"); // CHECK-EXEC: --- auto vm_darg0_B1 = clad::differentiate(&B1::vm, 0); -//CHECK: float vm_darg0(float x, float y) override { +//CHECK: float B1::vm_darg0(float x, float y) { //CHECK-NEXT: float _d_x = 1; //CHECK-NEXT: float _d_y = 0; // CHECK-NEXT: B1 _d_this_obj; @@ -263,7 +266,7 @@ int main () { //CHECK-NEXT: return _d_x * y + x * _d_y + _d_x * y + x * _d_y; //CHECK-NEXT: } auto vm_darg1_B1 = clad::differentiate(&B1::vm, 1); -//CHECK: float vm_darg1(float x, float y) override { +//CHECK: float B1::vm_darg1(float x, float y) { //CHECK-NEXT: float _d_x = 0; //CHECK-NEXT: float _d_y = 1; // CHECK-NEXT: B1 _d_this_obj; diff --git a/test/Functors/Simple.C b/test/Functors/Simple.C index cbd2bb804..a6a11ca71 100644 --- a/test/Functors/Simple.C +++ b/test/Functors/Simple.C @@ -32,7 +32,7 @@ public: float operator_call_darg0(float x, float y); }; -// CHECK: float operator_call_darg0(float x, float y) { +// CHECK: float SimpleExpression::operator_call_darg0(float x, float y) { // CHECK-NEXT: float _d_x = 1; // CHECK-NEXT: float _d_y = 0; // CHECK-NEXT: SimpleExpression _d_this_obj; diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index e2a4f172f..949b81ab1 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -353,7 +353,7 @@ struct S { return c1 * x + c2 * y; } - //CHECK: void f_grad(double x, double y, S *_d_this, double *_d_x, double *_d_y) { + //CHECK: void S::f_grad(double x, double y, S *_d_this, double *_d_x, double *_d_y) { //CHECK-NEXT: { //CHECK-NEXT: (*_d_this).c1 += 1 * x; //CHECK-NEXT: *_d_x += this->c1 * 1; diff --git a/test/Gradient/MemberFunctions.C b/test/Gradient/MemberFunctions.C index d704c67ea..d442270c7 100644 --- a/test/Gradient/MemberFunctions.C +++ b/test/Gradient/MemberFunctions.C @@ -21,7 +21,7 @@ public: double x, y; double mem_fn(double i, double j) { return (x + y) * i + i * j; } - // CHECK: void mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) { + // CHECK: void SimpleFunctions::mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -33,7 +33,7 @@ public: double const_mem_fn(double i, double j) const { return (x + y) * i + i * j; } - // CHECK: void const_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const { + // CHECK: void SimpleFunctions::const_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -47,7 +47,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void volatile_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile { + // CHECK: void SimpleFunctions::volatile_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -62,7 +62,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_volatile_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile { + // CHECK: void SimpleFunctions::const_volatile_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -75,7 +75,7 @@ public: double lval_ref_mem_fn(double i, double j) & { return (x + y) * i + i * j; } - // CHECK: void lval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) & { + // CHECK: void SimpleFunctions::lval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) & { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -89,7 +89,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_lval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const & { + // CHECK: void SimpleFunctions::const_lval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const & { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -103,7 +103,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void volatile_lval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile & { + // CHECK: void SimpleFunctions::volatile_lval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile & { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -118,7 +118,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_volatile_lval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile & { + // CHECK: void SimpleFunctions::const_volatile_lval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile & { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -131,7 +131,7 @@ public: double rval_ref_mem_fn(double i, double j) && { return (x + y) * i + i * j; } - // CHECK: void rval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) && { + // CHECK: void SimpleFunctions::rval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) && { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -145,7 +145,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_rval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const && { + // CHECK: void SimpleFunctions::const_rval_ref_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const && { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -159,7 +159,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void volatile_rval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile && { + // CHECK: void SimpleFunctions::volatile_rval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile && { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -174,7 +174,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_volatile_rval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile && { + // CHECK: void SimpleFunctions::const_volatile_rval_ref_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile && { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -189,7 +189,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) noexcept { + // CHECK: void SimpleFunctions::noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) noexcept { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -203,7 +203,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const noexcept { + // CHECK: void SimpleFunctions::const_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const noexcept { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -217,7 +217,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void volatile_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile noexcept { + // CHECK: void SimpleFunctions::volatile_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile noexcept { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -232,7 +232,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_volatile_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile noexcept { + // CHECK: void SimpleFunctions::const_volatile_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile noexcept { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -247,7 +247,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void lval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) & noexcept { + // CHECK: void SimpleFunctions::lval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) & noexcept { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -261,7 +261,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_lval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const & noexcept { + // CHECK: void SimpleFunctions::const_lval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const & noexcept { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -275,7 +275,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void volatile_lval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile & noexcept { + // CHECK: void SimpleFunctions::volatile_lval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile & noexcept { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -290,7 +290,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_volatile_lval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile & noexcept { + // CHECK: void SimpleFunctions::const_volatile_lval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile & noexcept { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -305,7 +305,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void rval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) && noexcept { + // CHECK: void SimpleFunctions::rval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) && noexcept { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -319,7 +319,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_rval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const && noexcept { + // CHECK: void SimpleFunctions::const_rval_ref_noexcept_mem_fn_grad(double i, double j, SimpleFunctions *_d_this, double *_d_i, double *_d_j) const && noexcept { // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; // CHECK-NEXT: (*_d_this).y += 1 * i; @@ -333,7 +333,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void volatile_rval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile && noexcept { + // CHECK: void SimpleFunctions::volatile_rval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) volatile && noexcept { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; @@ -348,7 +348,7 @@ public: return (x+y)*i + i*j; } - // CHECK: void const_volatile_rval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile && noexcept { + // CHECK: void SimpleFunctions::const_volatile_rval_ref_noexcept_mem_fn_grad(double i, double j, volatile SimpleFunctions *_d_this, double *_d_i, double *_d_j) const volatile && noexcept { // CHECK-NEXT: double _t0 = (this->x + this->y); // CHECK-NEXT: { // CHECK-NEXT: (*_d_this).x += 1 * i; diff --git a/test/Misc/RunDemos.C b/test/Misc/RunDemos.C index 377fe403d..64b1931ed 100644 --- a/test/Misc/RunDemos.C +++ b/test/Misc/RunDemos.C @@ -4,8 +4,6 @@ // RUN: %cladclang %S/../../demos/RosenbrockFunction.cpp -I%S/../../include 2>&1 // RUN: %cladclang %S/../../demos/ComputerGraphics/smallpt/SmallPT.cpp -I%S/../../include 2>&1 -// XFAIL: asserts - //-----------------------------------------------------------------------------/ // Demo: Gradient.cpp //-----------------------------------------------------------------------------/ diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index ee11c04bf..30a200abf 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -286,7 +286,7 @@ namespace clad { // decl or is contained in a namespace decl. // FIXME: We could get rid of this by prepending the produced // derivatives in CladPlugin::HandleTranslationUnitDecl - DeclContext* derivativeDC = DerivativeDecl->getDeclContext(); + DeclContext* derivativeDC = DerivativeDecl->getLexicalDeclContext(); bool isTUorND = derivativeDC->isTranslationUnit() || derivativeDC->isNamespace(); if (isTUorND) {