Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move the computation of parameters in SetupDerivativeParameters. #1191

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

#include <array>
#include <stack>
#include <unordered_map>
Expand Down Expand Up @@ -37,9 +39,6 @@ class BaseForwardModeVisitor

DerivativeAndOverload DerivePushforward();

/// Computes the return type of the derivative in `m_DiffReq->Function`.
clang::QualType ComputeDerivativeFunctionType();

virtual void ExecuteInsidePushforwardFunctionBlock();

static bool IsDifferentiableType(clang::QualType T);
Expand Down Expand Up @@ -148,6 +147,14 @@ class BaseForwardModeVisitor
const clang::CXXConstructExpr* CE,
llvm::SmallVectorImpl<clang::Expr*>& clonedArgs,
llvm::SmallVectorImpl<clang::Expr*>& derivedArgs);

private:
/// Computes the return type of the derivative in `m_DiffReq->Function`.
clang::QualType ComputeDerivativeFunctionType();

/// Prepares the derivative function parameters.
void
SetupDerivativeParameters(llvm::SmallVectorImpl<clang::ParmVarDecl*>& params);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
};
} // end namespace clad

Expand Down
132 changes: 64 additions & 68 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,32 +154,16 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

llvm::SmallVector<ParmVarDecl*, 4> params;
const ParmVarDecl* PVD = nullptr;

// Function declaration scope
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) {
PVD = FD->getParamDecl(i);
auto* newPVD = CloneParmVarDecl(PVD, PVD->getIdentifier(),
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);

// Make m_IndependentVar to point to the argument of the newly created
// derivedFD.
if (PVD == m_IndependentVar)
m_IndependentVar = newPVD;

params.push_back(newPVD);
}
llvm::SmallVector<ParmVarDecl*, 16> params;
SetupDerivativeParameters(params);
derivedFD->setParams(params);

llvm::ArrayRef<ParmVarDecl*> paramsRef =
clad_compat::makeArrayRef(params.data(), params.size());
derivedFD->setParams(paramsRef);
derivedFD->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -386,6 +370,64 @@ QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() {
return m_Context.getFunctionType(dRetTy, FnTypes, EPI);
}

void BaseForwardModeVisitor::SetupDerivativeParameters(
llvm::SmallVectorImpl<ParmVarDecl*>& params) {
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
const FunctionDecl* FD = m_DiffReq.Function;
for (ParmVarDecl* PVD : FD->parameters()) {
IdentifierInfo* PVDII = PVD->getIdentifier();
// Implicitly created special member functions have no parameter names.
if (!PVD->getDeclName())
PVDII = CreateUniqueIdentifier("param");

auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);

// Point m_IndependentVar to the argument of the newly created param.
if (PVD == m_IndependentVar)
m_IndependentVar = newPVD;

if (!PVD->getDeclName()) // We can't use lookup-based replacements
m_DeclReplacements[PVD] = newPVD;

params.push_back(newPVD);
}

if (m_DiffReq.Mode == DiffMode::forward)
return;

bool HasThis = false;
// If we are differentiating an instance member function then create a
// parameter for representing derivative of `this` pointer with respect to the
// independent parameter.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
IdentifierInfo* dThisII = CreateUniqueIdentifier("_d_this");
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, dThisII,
MD->getThisType());
m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false);
params.push_back(dPVD);
// FIXME: Replace m_ThisExprDerivative in favor of lookups of _d_this.
m_ThisExprDerivative = BuildDeclRef(dPVD);
HasThis = true;
}
}

for (size_t i = 0, e = params.size() - HasThis; i < e; ++i) {
const ParmVarDecl* PVD = params[i];

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;

IdentifierInfo* II = CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, II, GetPushForwardDerivativeType(PVD->getType()),
PVD->getStorageClass());
params.push_back(dPVD);
m_Variables[PVD] = BuildDeclRef(dPVD);
}
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
Expand Down Expand Up @@ -417,61 +459,15 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
llvm::SmallVector<ParmVarDecl*, 16> derivedParams;
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

// If we are differentiating an instance member function then
// create a parameter for representing derivative of
// `this` pointer with respect to the independent parameter.
if (const auto* MFD = dyn_cast<CXXMethodDecl>(FD)) {
if (MFD->isInstance()) {
auto thisType = MFD->getThisType();
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this");
auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext,
derivedPVDII, thisType);
m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(),
/*AddToContext=*/false);
derivedParams.push_back(derivedPVD);
m_ThisExprDerivative = BuildDeclRef(derivedPVD);
}
}

std::size_t numParamsOriginalFn = m_DiffReq->getNumParams();
for (std::size_t i = 0; i < numParamsOriginalFn; ++i) {
const auto* PVD = m_DiffReq->getParamDecl(i);
// Some of the special member functions created implicitly by compilers
// have missing parameter identifier.
bool identifierMissing = false;
IdentifierInfo* PVDII = PVD->getIdentifier();
if (!PVDII || PVDII->getLength() == 0) {
PVDII = CreateUniqueIdentifier("param");
identifierMissing = true;
}
auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);
params.push_back(newPVD);

if (identifierMissing)
m_DeclReplacements[PVD] = newPVD;

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;
auto derivedPVDName = "_d_" + std::string(PVDII->getName());
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName);
auto* derivedPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, derivedPVDII,
GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass());
derivedParams.push_back(derivedPVD);
m_Variables[newPVD] = BuildDeclRef(derivedPVD);
}

params.insert(params.end(), derivedParams.begin(), derivedParams.end());
llvm::SmallVector<ParmVarDecl*, 16> params;
SetupDerivativeParameters(params);
m_Derivative->setParams(params);

m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down
Loading