Skip to content

Commit

Permalink
Fix for Parenthesis issue in gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
toshit3q34 committed Dec 25, 2024
1 parent 3e50707 commit f086371
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
7 changes: 6 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ namespace clad {
/// either LHS or RHS is null.
clang::Expr* BuildOp(clang::BinaryOperatorKind OpCode, clang::Expr* L,
clang::Expr* R, clang::SourceLocation OpLoc = noLoc);

/// Function to resolve Unary Minus. If the leftmost operand
/// has a Unary Minus then adds parens before adding the unary minus.
/// \param[in] E Expression fed to the recursive call.
/// \param[in] OpLoc Location to add Unary Minus if needed.
/// \returns Expression with correct Unary Operator placement.
clang::Expr* ResolveUnaryMinus(clang::Expr* E, clang::SourceLocation OpLoc);
clang::Expr* BuildParens(clang::Expr* E);
/// Builds variable declaration to be used inside the derivative
/// body.
Expand Down
18 changes: 18 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "clang/Sema/Template.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"

#include <algorithm>
#include <numeric>
Expand Down Expand Up @@ -400,8 +401,25 @@ namespace clad {
// Debug clang requires the location to be valid
if (!OpLoc.isValid())
OpLoc = utils::GetValidSLoc(m_Sema);
// Call function for UnaryMinus
if (OpCode == UO_Minus)
return ResolveUnaryMinus(E->IgnoreCasts(), OpLoc);
return m_Sema.BuildUnaryOp(nullptr, OpLoc, OpCode, E).get();
}
Expr* VisitorBase::ResolveUnaryMinus(Expr* E, SourceLocation OpLoc) {
if (auto* UO = llvm::dyn_cast<clang::UnaryOperator>(E)) {
if (UO->getOpcode() == clang::UO_Minus)
return (UO->getSubExpr())->IgnoreParens();

Check warning on line 412 in lib/Differentiator/VisitorBase.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/VisitorBase.cpp#L411-L412

Added lines #L411 - L412 were not covered by tests
}
Expr* E_LHS = E;
while (auto* BO = llvm::dyn_cast<BinaryOperator>(E_LHS))
E_LHS = BO->getLHS();
if (auto* UO = llvm::dyn_cast<clang::UnaryOperator>(E_LHS->IgnoreCasts())) {
if (UO->getOpcode() == clang::UO_Minus)
E = m_Sema.ActOnParenExpr(E->getBeginLoc(), E->getEndLoc(), E).get();

Check warning on line 419 in lib/Differentiator/VisitorBase.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/VisitorBase.cpp#L418-L419

Added lines #L418 - L419 were not covered by tests
}
return m_Sema.BuildUnaryOp(nullptr, OpLoc, clang::UO_Minus, E).get();
}

Expr* VisitorBase::BuildOp(clang::BinaryOperatorKind OpCode, Expr* L, Expr* R,
SourceLocation OpLoc) {
Expand Down

0 comments on commit f086371

Please sign in to comment.