Skip to content

Commit

Permalink
Reroute pushforwards to the generic ::Derive method.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 27, 2024
1 parent c893c31 commit aa5cbdf
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 141 deletions.
4 changes: 1 addition & 3 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ class BaseForwardModeVisitor
///
DerivativeAndOverload Derive();

DerivativeAndOverload DerivePushforward();

virtual void ExecuteInsidePushforwardFunctionBlock();
virtual void ExecuteInsidePushforwardFunctionBlock() {}

static bool IsDifferentiableType(clang::QualType T);

Expand Down
213 changes: 77 additions & 136 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTLambda.h"
Expand All @@ -28,10 +29,13 @@
#include "clang/Sema/SemaInternal.h"
#include "clang/Sema/Template.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/SaveAndRestore.h"

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

Expand Down Expand Up @@ -67,75 +71,80 @@ bool IsRealNonReferenceType(QualType T) {

DerivativeAndOverload BaseForwardModeVisitor::Derive() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == DiffMode::forward);
assert(m_DiffReq.Mode == DiffMode::forward ||
m_DiffReq.Mode == DiffMode::experimental_pushforward ||
m_DiffReq.Mode == DiffMode::experimental_vector_pushforward);
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

DiffInputVarsInfo DVI = m_DiffReq.DVI;

// FIXME: Shouldn't we give error here that no arg is specified?
if (DVI.empty())
return {};

DiffInputVarInfo diffVarInfo = DVI.back();

// Check that only one arg is requested and if the arg requested is of array
// or pointer type, only one of the indices have been requested
if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) &&
(diffVarInfo.paramIndexInterval.size() != 1))) {
diag(DiagnosticsEngine::Error,
m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc,
"Forward mode differentiation w.r.t. several parameters at once is "
"not "
"supported, call 'clad::differentiate' for each parameter "
"separately");
return {};
}

// FIXME: implement gradient-vector products to fix the issue.
assert((DVI.size() == 1) &&
"nested forward mode differentiation for several args is broken");

// FIXME: Differentiation variable cannot always be represented just by
// `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc.
// FIXME: independent variable is misleading terminology, what we actually
// mean here is 'variable' with respect to which differentiation is being
// performed. Mathematically, independent variables are all the function
// parameters, thus, does not convey the intendend meaning.
m_IndependentVar = DVI.back().param;
// If param is not real (i.e. floating point or integral), a pointer to a
// real type, or an array of a real type we cannot differentiate it.
// FIXME: we should support custom numeric types in the future.
if (isArrayOrPointerType(m_IndependentVar->getType())) {
if (!m_IndependentVar->getType()
->getPointeeOrArrayElementType()
->isRealType()) {
diag(DiagnosticsEngine::Error, m_IndependentVar->getEndLoc(),
"attempted differentiation w.r.t. a parameter ('%0') which is not"
" an array or pointer of a real type",
{m_IndependentVar->getNameAsString()});
llvm::SaveAndRestore<bool> saveInFlight(m_DerivativeInFlight,
/*NewValue=*/true);

if (m_DiffReq.Mode == DiffMode::forward) {
DiffInputVarsInfo DVI = m_DiffReq.DVI;

// FIXME: Shouldn't we give error here that no arg is specified?
if (DVI.empty())
return {};
}
m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start;
} else {
QualType T = m_IndependentVar->getType();
bool isField = false;
if (auto RD = diffVarInfo.param->getType()->getAsCXXRecordDecl()) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
T = utils::ComputeMemExprPathType(m_Sema, RD, ref);
isField = true;
}
if (!IsRealNonReferenceType(T)) {
diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(),
"Attempted differentiation w.r.t. %0 '%1' which is not "
"of real type.",
{(isField ? "member" : "parameter"), diffVarInfo.source});

DiffInputVarInfo diffVarInfo = DVI.back();

// Check that only one arg is requested and if the arg requested is of array
// or pointer type, only one of the indices have been requested
if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) &&
(diffVarInfo.paramIndexInterval.size() != 1))) {
diag(DiagnosticsEngine::Error,
m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc,
"Forward mode differentiation w.r.t. several parameters at once is "
"not "
"supported, call 'clad::differentiate' for each parameter "
"separately");
return {};
}
}

// FIXME: implement gradient-vector products to fix the issue.
assert((DVI.size() == 1) &&
"nested forward mode differentiation for several args is broken");

// FIXME: Differentiation variable cannot always be represented just by
// `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc.
// FIXME: independent variable is misleading terminology, what we actually
// mean here is 'variable' with respect to which differentiation is being
// performed. Mathematically, independent variables are all the function
// parameters, thus, does not convey the intendend meaning.
m_IndependentVar = DVI.back().param;
// If param is not real (i.e. floating point or integral), a pointer to a
// real type, or an array of a real type we cannot differentiate it.
// FIXME: we should support custom numeric types in the future.
if (isArrayOrPointerType(m_IndependentVar->getType())) {
if (!m_IndependentVar->getType()
->getPointeeOrArrayElementType()
->isRealType()) {
diag(DiagnosticsEngine::Error, m_IndependentVar->getEndLoc(),
"attempted differentiation w.r.t. a parameter ('%0') which is not"
" an array or pointer of a real type",
{m_IndependentVar->getNameAsString()});
return {};
}
m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start;
} else {
QualType T = m_IndependentVar->getType();
bool isField = false;
if (auto* RD = diffVarInfo.param->getType()->getAsCXXRecordDecl()) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
T = utils::ComputeMemExprPathType(m_Sema, RD, ref);
isField = true;
}
if (!IsRealNonReferenceType(T)) {
diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(),
"Attempted differentiation w.r.t. %0 '%1' which is not "
"of real type.",
{(isField ? "member" : "parameter"), diffVarInfo.source});
return {};
}
}
}
// Check if the function is already declared as a custom derivative.
std::string gradientName = m_DiffReq.ComputeDerivativeName();

Expand Down Expand Up @@ -177,6 +186,10 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
m_DerivativeFnScope = getCurrentScope();
beginBlock();

// FIXME: Remove the override in VectorPushForwardModeVisitor.
if (m_DiffReq.Mode == DiffMode::experimental_vector_pushforward)
ExecuteInsidePushforwardFunctionBlock();

if (m_DiffReq.Mode == DiffMode::forward)
GenerateSeeds(derivedFD);

Expand All @@ -202,8 +215,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
m_Sema.PopDeclContext();
endScope(); // Function decl scope

m_DerivativeInFlight = false;

return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
}
Expand Down Expand Up @@ -285,7 +296,7 @@ void BaseForwardModeVisitor::SetupDerivativeParameters(
// independent parameter.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
IdentifierInfo* dThisII = CreateUniqueIdentifier("_d_this");
IdentifierInfo* dThisII = &m_Context.Idents.get("_d_this");
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, dThisII,
MD->getThisType());
m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false);
Expand All @@ -302,7 +313,7 @@ void BaseForwardModeVisitor::SetupDerivativeParameters(
if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;

IdentifierInfo* II = CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
IdentifierInfo* II = &m_Context.Idents.get("_d_" + PVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, II, GetPushForwardDerivativeType(PVD->getType()),
PVD->getStorageClass());
Expand Down Expand Up @@ -437,76 +448,6 @@ void BaseForwardModeVisitor::GenerateSeeds(const clang::FunctionDecl* dFD) {
}
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
for (Stmt* S : CS->body())
addToCurrentBlock(S);
}

DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
const FunctionDecl* FD = m_DiffReq.Function;
assert(m_DiffReq.Mode == GetPushForwardMode());
assert(!m_DerivativeInFlight &&
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());

IdentifierInfo* II = &m_Context.Idents.get(m_DiffReq.ComputeDerivativeName());
SourceLocation loc{FD->getLocation()};
DeclarationNameInfo derivedFnName(II, loc);

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
QualType derivedFnType = ComputeDerivativeFunctionType();
DeclWithContext cloneFunctionResult =
m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

llvm::SmallVector<ParmVarDecl*, 16> params;
SetupDerivativeParameters(params);
m_Derivative->setParams(params);

m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

// execute the functor inside the function body.
ExecuteInsidePushforwardFunctionBlock();

Stmt* derivativeBody = endBlock();
m_Derivative->setBody(derivativeBody);
endScope(); // Function body scope

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (m_DiffReq.DerivedFDPrototypes.size() >=
m_DiffReq.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]);
}

m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

m_DerivativeInFlight = false;
return DerivativeAndOverload{cloneFunctionResult.first};
}

StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) {
diag(DiagnosticsEngine::Warning, S->getBeginLoc(),
"attempted to differentiate unsupported statement, no changes applied");
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
result = V.Derive();
} else if (request.Mode == DiffMode::experimental_pushforward) {
PushForwardModeVisitor V(*this, request);
result = V.DerivePushforward();
result = V.Derive();
} else if (request.Mode == DiffMode::vector_forward_mode) {
VectorForwardModeVisitor V(*this, request);
result = V.DeriveVectorMode();
} else if (request.Mode == DiffMode::experimental_vector_pushforward) {
VectorPushForwardModeVisitor V(*this, request);
result = V.DerivePushforward();
result = V.Derive();
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this, request);
result = V.Derive();
Expand Down

0 comments on commit aa5cbdf

Please sign in to comment.