Skip to content

Commit

Permalink
Do not remove const in utils::GetValueType
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Feb 19, 2025
1 parent 403c348 commit 51670da
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 21 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ namespace clad {
/// is neither an array nor a pointer type, then simply returns `T`.
clang::QualType GetValueType(clang::QualType T);

/// Returns the same type as GetValueType but without const qualifier.
clang::QualType GetNonConstValueType(clang::QualType T);

/// Builds and returns the init expression to initialise `clad::array` and
/// `clad::array_ref` from a constant array.
///
Expand Down
18 changes: 12 additions & 6 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/Compatibility.h"

#include "ConstantFolder.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Type.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Lookup.h"
#include "llvm/ADT/SmallVector.h"
#include <clang/AST/DeclCXX.h>
#include "clad/Differentiator/Compatibility.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"

using namespace clang;
namespace clad {
Expand Down Expand Up @@ -380,14 +383,17 @@ namespace clad {
valueType = T->getPointeeType();
else if (T->isReferenceType())
valueType = T.getNonReferenceType();
// FIXME: `QualType::getPointeeOrArrayElementType` loses type qualifiers.
else if (T->isArrayType())
valueType =
T->getPointeeOrArrayElementType()->getCanonicalTypeInternal();
else if (const auto* AT = dyn_cast<clang::ArrayType>(T))
valueType = AT->getElementType();
else if (T->isEnumeralType()) {
if (const auto* ET = dyn_cast<EnumType>(T))
valueType = ET->getDecl()->getIntegerType();
}
return valueType;
}

clang::QualType GetNonConstValueType(clang::QualType T) {
QualType valueType = GetValueType(T);
valueType.removeLocalConst();
return valueType;
}
Expand Down
8 changes: 2 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4244,10 +4244,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
}

QualType ReverseModeVisitor::GetParameterDerivativeType(QualType Type) {

QualType ValueType = utils::GetValueType(Type);
// derivative variables should always be of non-const type.
ValueType.removeLocalConst();
QualType ValueType = utils::GetNonConstValueType(Type);
QualType nonRefValueType = ValueType.getNonReferenceType();
return m_Context.getPointerType(nonRefValueType);
}
Expand All @@ -4266,8 +4263,7 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {

clang::QualType ReverseModeVisitor::ComputeAdjointType(clang::QualType T) {
if (T->isReferenceType()) {
QualType TValueType = utils::GetValueType(T);
TValueType.removeLocalConst();
QualType TValueType = utils::GetNonConstValueType(T);
return m_Context.getPointerType(TValueType);
}
T.removeLocalConst();
Expand Down
18 changes: 9 additions & 9 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ QualType
VectorForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) {
if (ParamType == m_Context.VoidTy)
return ParamType;
QualType valueType = utils::GetValueType(ParamType);
QualType valueType = utils::GetNonConstValueType(ParamType);
QualType resType;
if (utils::isArrayOrPointerType(ParamType)) {
// If the parameter is a pointer or an array, then the derivative will be a
Expand Down Expand Up @@ -75,8 +75,7 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
if (it == std::end(args))
continue; // This parameter is not in the diff list.

QualType valueType = utils::GetValueType(PVD->getType());
valueType.removeLocalConst();
QualType valueType = utils::GetNonConstValueType(PVD->getType());
QualType dParamType;
if (utils::isArrayOrPointerType(PVD->getType())) {
// Generate array reference type for the derivative.
Expand Down Expand Up @@ -160,7 +159,7 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() {
bool is_array =
utils::isArrayOrPointerType(m_DiffReq->getParamDecl(i)->getType());
auto param = params[i];
QualType dParamType = clad::utils::GetValueType(param->getType());
QualType dParamType = clad::utils::GetNonConstValueType(param->getType());

Expr* dVectorParam = nullptr;
if (m_IndependentVars.size() > independentVarIndex &&
Expand Down Expand Up @@ -503,7 +502,8 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
Expr* derivedRetValE = retValDiff.getExpr_dx();
// If we are in vector mode, we need to wrap the return value in a
// vector.
QualType cladArrayType = GetCladArrayOfType(utils::GetValueType(retType));
QualType cladArrayType =
GetCladArrayOfType(utils::GetNonConstValueType(retType));
VarDecl* dVectorParamDecl = BuildVarDecl(cladArrayType, "_d_vector_return",
derivedRetValE, /*DirectInit=*/true);
// Create an array of statements to hold the return statement and the
Expand Down Expand Up @@ -581,10 +581,10 @@ VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
// This may not necessarily be true in the future.
VarDecl* VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit());
VarDecl* VDDerived =
BuildVarDecl(GetCladArrayOfType(utils::GetValueType(VD->getType())),
"_d_vector_" + VD->getNameAsString(), initDiff.getExpr_dx(),
/*DirectInit=*/true);
VarDecl* VDDerived = BuildVarDecl(
GetCladArrayOfType(utils::GetNonConstValueType(VD->getType())),
"_d_vector_" + VD->getNameAsString(), initDiff.getExpr_dx(),
/*DirectInit=*/true);

m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return DeclDiff<VarDecl>(VDClone, VDDerived);
Expand Down

0 comments on commit 51670da

Please sign in to comment.