Skip to content

Commit

Permalink
Clean up the code computing the parameters and their types.
Browse files Browse the repository at this point in the history
This patch harmonizes the steps to build the derived function prototype along
with its parameters. It refactors a lot of tricky code that grew organically
and included many workarounds. This is only a single step in a direction of
merging the common code patterns across visitors.

As a bonus we do not generate broken code when we have derivative parameters
named `_d_x`. That's common in computing hessians.
  • Loading branch information
vgvassilev committed Jan 4, 2025
1 parent 825396d commit 6329ae8
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 175 deletions.
18 changes: 7 additions & 11 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#include <stack>
#include <unordered_map>

namespace llvm {
template <typename T> class SmallVectorImpl;
}

namespace clad {
class ErrorEstimationHandler;
class ExternalRMVSource;
Expand Down Expand Up @@ -682,19 +686,11 @@ namespace clad {
///\paramp[in] source An external RMV source
void AddExternalSource(ExternalRMVSource& source);

/// Computes and returns the sequence of derived function parameter types.
///
/// Information about the original function and the differentiation mode
/// are taken from the data member variables.
llvm::SmallVector<clang::QualType, 8> ComputeParamTypes(const DiffParams& diffParams);
/// Computes and returns the derived function prototype.
clang::QualType ComputeDerivativeFunctionType();

/// Builds and returns the sequence of derived function parameters.
///
/// Information about the original function, derived function, derived
/// function parameter types and the differentiation mode are implicitly
/// taken from the data member variables.
llvm::SmallVector<clang::ParmVarDecl*, 8>
BuildParams(DiffParams& diffParams);
void BuildParams(llvm::SmallVectorImpl<clang::ParmVarDecl*>& params);

clang::QualType ComputeAdjointType(clang::QualType T);
clang::QualType ComputeParamType(clang::QualType T);
Expand Down
Loading

0 comments on commit 6329ae8

Please sign in to comment.