diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d8d4b0cda..5d67f334e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2516,9 +2516,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; - // FIXME: find a more reliable way to determine if the declaration - // is in the function global scope. - bool isInFunctionGlobalScope = m_Reverse.size() <= 2; + // We take the parent of the current scope because the main compound + // statement of the function has its own scope as well. + bool isInFunctionGlobalScope = getCurrentScope()->getParent()==m_DerivativeFnScope; auto VDDerivedType = ComputeAdjointType(VD->getType()); auto VDCloneType = CloneType(VD->getType()); if (!isInFunctionGlobalScope) @@ -2725,9 +2725,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector declsDiff; // Need to put array decls inlined. llvm::SmallVector localDeclsDiff; - // FIXME: find a more reliable way to determine if the declaration - // is in the function global scope. - bool isInFunctionGlobalScope = m_Reverse.size() <= 2; + // We take the parent of the current scope because the main compound + // statement of the function has its own scope as well. + bool isInFunctionGlobalScope = getCurrentScope()->getParent()==m_DerivativeFnScope; // For each variable declaration v, create another declaration _d_v to // store derivatives for potential reassignments. E.g. // double y = x;