Skip to content

Commit

Permalink
Simplify error estimation by not storing m_ParamTypes
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Jan 1, 2025
1 parent 693bff0 commit 825396d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
1 change: 0 additions & 1 deletion include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ErrorEstimationHandler : public ExternalRMVSource {

std::stack<bool> m_ShouldEmit;
ReverseModeVisitor* m_RMV;
llvm::SmallVectorImpl<clang::QualType>* m_ParamTypes = nullptr;
llvm::SmallVectorImpl<clang::ParmVarDecl*>* m_Params = nullptr;

public:
Expand Down
16 changes: 7 additions & 9 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,25 +281,23 @@ void ErrorEstimationHandler::ActBeforeCreatingDerivedFnParamTypes(

void ErrorEstimationHandler::ActAfterCreatingDerivedFnParamTypes(
llvm::SmallVectorImpl<QualType>& paramTypes) {
m_ParamTypes = &paramTypes;
// If we are performing error estimation, our gradient function
// will have an extra argument which will hold the final error value
paramTypes.push_back(
m_RMV->m_Context.getLValueReferenceType(m_RMV->m_Context.DoubleTy));
ASTContext& C = m_RMV->m_Context;
paramTypes.push_back(C.getLValueReferenceType(C.DoubleTy));
}

void ErrorEstimationHandler::ActAfterCreatingDerivedFnParams(
llvm::SmallVectorImpl<ParmVarDecl*>& params) {
m_Params = &params;
// If in error estimation mode, create the error parameter
ASTContext& context = m_RMV->m_Context;
ASTContext& C = m_RMV->m_Context;
// Repeat the above but for the error ouput var "_final_error"
QualType LastParamTy = C.getLValueReferenceType(C.DoubleTy);
ParmVarDecl* errorVarDecl = ParmVarDecl::Create(
context, m_RMV->m_Derivative, noLoc, noLoc,
&context.Idents.get("_final_error"), m_ParamTypes->back(),
context.getTrivialTypeSourceInfo(m_ParamTypes->back(), noLoc),
params.front()->getStorageClass(),
/*DefArg=*/nullptr);
C, m_RMV->m_Derivative, noLoc, noLoc, &C.Idents.get("_final_error"),
LastParamTy, C.getTrivialTypeSourceInfo(LastParamTy, noLoc),
params.front()->getStorageClass(), /*DefArg=*/nullptr);
params.push_back(errorVarDecl);
m_RMV->m_Sema.PushOnScopeChains(params.back(), m_RMV->getCurrentScope(),
/*AddToContext=*/false);
Expand Down

0 comments on commit 825396d

Please sign in to comment.