Skip to content

Commit

Permalink
Improve derivative registration.
Browse files Browse the repository at this point in the history
Once the derivative's FunctionDecl is created it needs to be registered properly
in the AST. Before this patch `registerDerivative` ignored many checks which
make sure the produced code is valid C++. This patch enforces more rigorous
checks for some declaration kinds such as standalone functions and some cases
of class functions. Lambdas will be handled in subsequent patches.

The challenge is to produce valid C++ code when differentiating classes.
Consider:
```cpp
class A1 {
  double f(double x) { return x * x; }
  double f_grad(double x, double *d_x); // forward declaration of the gradient.
};
```

In such cases clad will now produce a correct out-of-line definition of `f_grad`:
`double A1::f_grad(double x, double *d_x) { ... }`.

In cases where the gradient function is not declared as part of the class
signature (external libraries such as STL), we continue to handle them as
before. That will be dealt with in a subsequent patches.

This also fixes assertions in debug mode when there are virtual functions.
  • Loading branch information
vgvassilev committed Jan 7, 2025
1 parent c2f1acf commit 584eb31
Show file tree
Hide file tree
Showing 12 changed files with 136 additions and 115 deletions.
28 changes: 21 additions & 7 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,30 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
addToCurrentBlock(BodyDiff);
Stmt* derivativeBody = endBlock();
derivedFD->setBody(derivativeBody);

endScope(); // Function body scope
}

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
// FIXME: Drop the static specifier for the out-of-line definitions.
if (auto* RD = dyn_cast<RecordDecl>(m_Derivative->getDeclContext())) {
DeclContext::lookup_result R =
RD->getPrimaryContext()->lookup(m_Derivative->getDeclName());
FunctionDecl* FoundFD =
R.empty() ? nullptr : dyn_cast<FunctionDecl>(R.front());
if (!RD->isLambda() && !R.empty() &&
!m_Builder.m_DFC.IsCladDerivative(FoundFD)) {
Sema::NestedNameSpecInfo IdInfo(RD->getIdentifier(), noLoc, noLoc,
/*ObjectType=*/nullptr);
// FIXME: Address nested classes where SS should be set.
CXXScopeSpec SS;
m_Sema.BuildCXXNestedNameSpecifier(getCurrentScope(), IdInfo,
/*EnteringContext=*/true, SS,
/*ScopeLookupResult=*/nullptr,
/*ErrorRecoveryLookup=*/false);
m_Derivative->setQualifierInfo(SS.getWithLocInContext(m_Context));
m_Derivative->setLexicalDeclContext(RD->getParent());
}
}

m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope
Expand Down
98 changes: 48 additions & 50 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@

#include "JacobianModeVisitor.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Basic/LLVM.h" // isa, dyn_cast
#include "clang/Basic/Specifiers.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
#include "clang/Sema/Scope.h"
Expand Down Expand Up @@ -50,48 +54,33 @@ DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,

DerivativeBuilder::~DerivativeBuilder() {}

static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
LookupResult R(semaRef, derivedFD->getNameInfo(), Sema::LookupOrdinaryName);
// FIXME: Attach out-of-line virtual function definitions to the TUScope.
Scope* S = semaRef.getScopeForContext(derivedFD->getDeclContext());
semaRef.CheckFunctionDeclaration(
S, derivedFD, R,
/*IsMemberSpecialization=*/
false
/*DeclIsDefn*/ CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam(
derivedFD));

// FIXME: Avoid the DeclContext lookup and the manual setPreviousDecl.
// Consider out-of-line virtual functions.
{
DeclContext* LookupCtx = derivedFD->getDeclContext();
// Find the first non-transparent context to perform the lookup in.
while (LookupCtx->isTransparentContext())
LookupCtx = LookupCtx->getParent();
auto R = LookupCtx->noload_lookup(derivedFD->getDeclName());

for (NamedDecl* I : R) {
if (auto* FD = dyn_cast<FunctionDecl>(I)) {
// FIXME: We still do extra work in creating a derivative and throwing
// it away.
if (FD->getDefinition())
return;

if (derivedFD->getASTContext().hasSameFunctionTypeIgnoringExceptionSpec(
derivedFD->getType(), FD->getType())) {
// Register the function on the redecl chain.
derivedFD->setPreviousDecl(FD);
break;
}
}
}
// Inform the decl's decl context for its existance after the lookup,
// otherwise it would end up in the LookupResult.
derivedFD->getDeclContext()->addDecl(derivedFD);

// FIXME: Rebuild VTable to remove requirements for "forward" declared
// virtual methods
static void registerDerivative(FunctionDecl* dFD, Sema& S,
const DiffRequest& R) {
DeclContext* DC = dFD->getLexicalDeclContext();
LookupResult Previous(S, dFD->getNameInfo(), Sema::LookupOrdinaryName);
S.LookupQualifiedName(Previous, dFD->getParent());

// Check if we created a top-level decl with the same name for another class.
// FIXME: This case should be addressed by providing proper names and function
// implementation that does not rely on accessing private data from the class.
bool IsBrokenDecl = isa<RecordDecl>(DC);
if (!IsBrokenDecl) {
S.CheckFunctionDeclaration(
/*Scope=*/nullptr, dFD, Previous,
/*IsMemberSpecialization=*/
false
/*DeclIsDefn*/
CLAD_COMPAT_CheckFunctionDeclaration_DeclIsDefn_ExtraParam(dFD));
} else if (R.DerivedFDPrototypes.size() >= R.CurrentDerivativeOrder) {
// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
dFD->setPreviousDecl(R.DerivedFDPrototypes[R.CurrentDerivativeOrder - 1]);
}

if (dFD->isInvalidDecl())
return; // CheckFunctionDeclaration was unhappy about derivedFD

DC->addDecl(dFD);
}

static bool hasAttribute(const Decl *D, attr::Kind Kind) {
Expand All @@ -107,33 +96,42 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
clang::DeclarationNameInfo name, clang::QualType functionType) {
FunctionDecl* returnedFD = nullptr;
NamespaceDecl* enclosingNS = nullptr;
TypeSourceInfo* TSI = m_Context.getTrivialTypeSourceInfo(functionType);
if (isa<CXXMethodDecl>(FD)) {
CXXRecordDecl* CXXRD = cast<CXXRecordDecl>(DC);
returnedFD = CXXMethodDecl::Create(
m_Context, CXXRD, noLoc, name, functionType, FD->getTypeSourceInfo(),
m_Context, CXXRD, noLoc, name, functionType, TSI,
FD->getCanonicalDecl()->getStorageClass()
CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(FD),
FD->isInlineSpecified(), clad_compat::Function_GetConstexprKind(FD),
noLoc);
returnedFD->setAccess(FD->getAccess());
// Generated member function should be called outside of class definitions
// even if their original function had different access specifier.
returnedFD->setAccess(AS_public);
} else {
assert (isa<FunctionDecl>(FD) && "Unexpected!");
enclosingNS = VB.RebuildEnclosingNamespaces(DC);
returnedFD = FunctionDecl::Create(
m_Context, m_Sema.CurContext, noLoc, name, functionType,
FD->getTypeSourceInfo(),
m_Context, m_Sema.CurContext, noLoc, name, functionType, TSI,
FD->getCanonicalDecl()->getStorageClass()
CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(FD),
FD->isInlineSpecified(), FD->hasWrittenPrototype(),
clad_compat::Function_GetConstexprKind(FD)
CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(
FD->getTrailingRequiresClause()));
}

for (const FunctionDecl* NFD : FD->redecls())
for (const auto* Attr : NFD->attrs())
returnedFD->setAccess(FD->getAccess());
}

for (const FunctionDecl* NFD : FD->redecls()) {
for (const auto* Attr : NFD->attrs()) {
// We only need the keywords final and override in the tag declaration.
if (isa<OverrideAttr>(Attr) || isa<FinalAttr>(Attr))
continue;
if (!hasAttribute(returnedFD, Attr->getKind()))
returnedFD->addAttr(Attr->clone(m_Context));
}
}

return { returnedFD, enclosingNS };
}
Expand Down Expand Up @@ -475,9 +473,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
// derivative is a member function it goes into an infinite loop
if (!m_DFC.IsCustomDerivative(result.derivative)) {
if (auto* FD = result.derivative)
registerDerivative(FD, m_Sema);
registerDerivative(FD, m_Sema, request);
if (auto* OFD = result.overload)
registerDerivative(OFD, m_Sema);
registerDerivative(OFD, m_Sema, request);
}

return result;
Expand Down
7 changes: 0 additions & 7 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,6 @@ DerivativeAndOverload ReverseModeForwPassVisitor::Derive() {
Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope();

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
Expand Down
30 changes: 22 additions & 8 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,20 +298,34 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope(); // Function body scope

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
m_DiffReq
.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
endScope(); // Function body scope
}

m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

if (auto* RD = dyn_cast<RecordDecl>(m_Derivative->getDeclContext())) {
DeclContext::lookup_result R =
RD->getPrimaryContext()->lookup(m_Derivative->getDeclName());
FunctionDecl* FoundFD =
R.empty() ? nullptr : dyn_cast<FunctionDecl>(R.front());
if (!RD->isLambda() && !R.empty() &&
!m_Builder.m_DFC.IsCladDerivative(FoundFD)) {
Sema::NestedNameSpecInfo IdInfo(RD->getIdentifier(), noLoc, noLoc,
/*ObjectType=*/nullptr);
// FIXME: Address nested classes where SS should be set.
CXXScopeSpec SS;
m_Sema.BuildCXXNestedNameSpecifier(getCurrentScope(), IdInfo,
/*EnteringContext=*/true, SS,
/*ScopeLookupResult=*/nullptr,
/*ErrorRecoveryLookup=*/false);
m_Derivative->setQualifierInfo(SS.getWithLocInContext(m_Context));
m_Derivative->setLexicalDeclContext(RD->getParent());
}
}

if (!shouldCreateOverload)
return DerivativeAndOverload{result.first, /*overload=*/nullptr};

Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,9 @@ namespace clad {
.get();
isArrow = false;
}
NestedNameSpecifierLoc NNS(FD->getQualifier(),
/*Data=*/nullptr);
// Leads to printing this->Class::Function(x, y).
// FIXME: Enable for static functions.
NestedNameSpecifierLoc NNS /* = FD->getQualifierLoc()*/;
auto DAP = DeclAccessPair::make(FD, FD->getAccess());
auto* memberExpr = MemberExpr::Create(
m_Context, thisExpr, isArrow, Loc, NNS, noLoc, FD, DAP,
Expand Down
6 changes: 3 additions & 3 deletions test/FirstDerivative/ClassMethodCall.C
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public:
}

float vm_darg0(float x, float y);
//CHECK: float vm_darg0(float x, float y) {
//CHECK: float A::vm_darg0(float x, float y) {
//CHECK-NEXT: float _d_x = 1;
//CHECK-NEXT: float _d_y = 0;
//CHECK-NEXT: A _d_this_obj;
Expand All @@ -91,7 +91,7 @@ public:
}

float vm_darg0(float x, float y);
//CHECK: float vm_darg0(float x, float y) override {
//CHECK: float B::vm_darg0(float x, float y) {
//CHECK-NEXT: float _d_x = 1;
//CHECK-NEXT: float _d_y = 0;
//CHECK-NEXT: B _d_this_obj;
Expand All @@ -117,7 +117,7 @@ int main () {
auto vm_darg0_B = clad::differentiate(&B::vm, 0);
printf("Result is = %f\n", vm_darg0_B.execute(b, 2, 3)); // CHECK-EXEC: Result is = 4.0000
printf("%s\n", vm_darg0_B.getCode());
//CHECK-EXEC: float vm_darg0(float x, float y) override {
//CHECK-EXEC: float B::vm_darg0(float x, float y) {
//CHECK-EXEC-NEXT: float _d_x = 1;
//CHECK-EXEC-NEXT: float _d_y = 0;
//CHECK-EXEC-NEXT: B _d_this_obj;
Expand Down
21 changes: 12 additions & 9 deletions test/FirstDerivative/VirtualMethodsCall.C
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// RUN: %cladclang %s -I%S/../../include -oVirtualMethodsCall.out 2>&1 | %filecheck %s
// RUN: ./VirtualMethodsCall.out | %filecheck_exec %s
// XFAIL: asserts

#include "clad/Differentiator/Differentiator.h"

Expand Down Expand Up @@ -41,15 +40,15 @@ public:
// TODO: Remove call forward when execute works with polymorphic methods
auto vm_darg0_cf = clad::differentiate(&A::vm, 0);
// FIXME: We need to make this out-of-line
//CHECK: float vm_darg0(float x, float y) {
//CHECK: float A::vm_darg0(float x, float y) {
//CHECK-NEXT: float _d_x = 1;
//CHECK-NEXT: float _d_y = 0;
//CHECK-NEXT: A _d_this_obj;
//CHECK-NEXT: A *_d_this = &_d_this_obj;
//CHECK-NEXT: return _d_x + _d_y;
//CHECK-NEXT: }
auto vm_darg1_cf = clad::differentiate(&A::vm, 1);
//CHECK: float vm_darg1(float x, float y) {
//CHECK: float A::vm_darg1(float x, float y) {
//CHECK-NEXT: float _d_x = 0;
//CHECK-NEXT: float _d_y = 1;
//CHECK-NEXT: A _d_this_obj;
Expand Down Expand Up @@ -130,6 +129,9 @@ public:
return x*y + x*y;
}

float vm_darg0(float x, float y) override; // forward
float vm_darg1(float x, float y) override; // forward

// Inherited from A
// float vm1(float x, float y)...

Expand Down Expand Up @@ -180,15 +182,16 @@ int main () {
printf("---\n"); // CHECK-EXEC: ---

auto vm_darg0_A = clad::differentiate((float(A::*)(float,float))&A::vm, 0);
//CHECK: float vm_darg0(float x, float y) override {
//FIXME: This should select the implementation of vm in A!
//CHECK: float B::vm_darg0(float x, float y) {
//CHECK-NEXT: float _d_x = 1;
//CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: B _d_this_obj;
// CHECK-NEXT: B *_d_this = &_d_this_obj;
//CHECK-NEXT: return _d_x * x + x * _d_x + _d_y * y + y * _d_y;
//CHECK-NEXT: }
auto vm_darg0_B = clad::differentiate((float(B::*)(float,float))&B::vm, 0);
//CHECK: float vm_darg1(float x, float y) override {
//CHECK: float B::vm_darg1(float x, float y) {
//CHECK-NEXT: float _d_x = 0;
//CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: B _d_this_obj;
Expand All @@ -208,7 +211,7 @@ int main () {
//CHECK-NEXT: return _d_x - _d_y;
//CHECK-NEXT: }
auto vm1_darg0_B = clad::differentiate((float(B::*)(float,float))&B::vm1, 0);
//CHECK: float vm1_darg0(float x, float y) override {
//CHECK: float vm1_darg0(float x, float y) {
//CHECK-NEXT: float _d_x = 1;
//CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: B _d_this_obj;
Expand All @@ -224,7 +227,7 @@ int main () {
//CHECK-NEXT: return _d_x - _d_y;
//CHECK-NEXT: }
auto vm1_darg1_B = clad::differentiate((float(B::*)(float,float))&B::vm1, 1);
//CHECK: float vm1_darg1(float x, float y) override {
//CHECK: float vm1_darg1(float x, float y) {
//CHECK-NEXT: float _d_x = 0;
//CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: B _d_this_obj;
Expand Down Expand Up @@ -255,15 +258,15 @@ int main () {
printf("---\n"); // CHECK-EXEC: ---

auto vm_darg0_B1 = clad::differentiate(&B1::vm, 0);
//CHECK: float vm_darg0(float x, float y) override {
//CHECK: float B1::vm_darg0(float x, float y) {
//CHECK-NEXT: float _d_x = 1;
//CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: B1 _d_this_obj;
// CHECK-NEXT: B1 *_d_this = &_d_this_obj;
//CHECK-NEXT: return _d_x * y + x * _d_y + _d_x * y + x * _d_y;
//CHECK-NEXT: }
auto vm_darg1_B1 = clad::differentiate(&B1::vm, 1);
//CHECK: float vm_darg1(float x, float y) override {
//CHECK: float B1::vm_darg1(float x, float y) {
//CHECK-NEXT: float _d_x = 0;
//CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: B1 _d_this_obj;
Expand Down
2 changes: 1 addition & 1 deletion test/Functors/Simple.C
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public:
float operator_call_darg0(float x, float y);
};

// CHECK: float operator_call_darg0(float x, float y) {
// CHECK: float SimpleExpression::operator_call_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: SimpleExpression _d_this_obj;
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ struct S {
return c1 * x + c2 * y;
}

//CHECK: void f_grad(double x, double y, S *_d_this, double *_d_x, double *_d_y) {
//CHECK: void S::f_grad(double x, double y, S *_d_this, double *_d_x, double *_d_y) {
//CHECK-NEXT: {
//CHECK-NEXT: (*_d_this).c1 += 1 * x;
//CHECK-NEXT: *_d_x += this->c1 * 1;
Expand Down
Loading

0 comments on commit 584eb31

Please sign in to comment.