Skip to content

Commit

Permalink
Unify ArrMap and ObjMap.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 3, 2023
1 parent 3d34780 commit c9ba780
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 123 deletions.
100 changes: 26 additions & 74 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,27 @@ namespace clad {

class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
private:
/// Used to provide a hash function for an unordered_map with llvm::APInt
/// type keys.
struct APIntHash {
size_t operator()(const llvm::APInt& x) const {
return llvm::hash_value(x);
}
};

static bool eqAPInt(const llvm::APInt& x, const llvm::APInt& y) {
if (x.getBitWidth() != y.getBitWidth())
return false;
return x == y;
/// ProfileID is the key type for ArrMap used to represent array indices
/// and object fields.
using ProfileID = clad_compat::FoldingSetNodeID;

ProfileID getProfileID(const Expr* E) const{
ProfileID profID;
E->Profile(profID, m_Context, /* Canonical */ true);
return profID;
}

struct APIntComp {
bool operator()(const llvm::APInt& x, const llvm::APInt& y) const {
return eqAPInt(x, y);
}
};
static ProfileID getProfileID(const FieldDecl* FD) {
ProfileID profID;
profID.AddPointer(FD);
return profID;
}

/// Just a helper struct serving as a wrapper for IdxOrMemberValue union.
/// Used to unwrap expressions like a[6].x.t[3]. Only used in
/// TBRAnalyzer::overlay().
struct IdxOrMember {
enum IdxOrMemberType { FIELD, INDEX };
union IdxOrMemberValue {
const clang::FieldDecl* field;
llvm::APInt index;
IdxOrMemberValue() : field(nullptr) {}
~IdxOrMemberValue() {}
IdxOrMemberValue(const IdxOrMemberValue&) = delete;
IdxOrMemberValue& operator=(const IdxOrMemberValue&) = delete;
IdxOrMemberValue(const IdxOrMemberValue&&) = delete;
IdxOrMemberValue& operator=(const IdxOrMemberValue&&) = delete;
};
IdxOrMemberType type;
IdxOrMemberValue val;
IdxOrMember(const clang::FieldDecl* field) : type(IdxOrMemberType::FIELD) {
val.field = field;
struct ProfileIDHash {
size_t operator()(const ProfileID& x) const {
return x.ComputeHash();
}
IdxOrMember(llvm::APInt&& index) : type(IdxOrMemberType::INDEX) {
new (&val.index) llvm::APInt(index);
}
IdxOrMember(const IdxOrMember& other) {
new (&val.index) llvm::APInt();
*this = other;
}
IdxOrMember(const IdxOrMember&& other) noexcept {
new (&val.index) llvm::APInt();
*this = other;
}
IdxOrMember& operator=(const IdxOrMember& other) {
type = other.type;
if (type == IdxOrMemberType::FIELD)
val.field = other.val.field;
else
val.index = other.val.index;
return *this;
}
IdxOrMember& operator=(const IdxOrMember&& other) noexcept {
return *this = other;
}
~IdxOrMember() = default;
};

/// Stores all the necessary information about one variable. Fundamental type
Expand All @@ -93,17 +51,16 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// 'double& x = f(b);' is not supported.

struct VarData;
using ObjMap = std::unordered_map<const clang::FieldDecl*, VarData>;
using ArrMap =
std::unordered_map<const llvm::APInt, VarData, APIntHash, APIntComp>;
std::unordered_map<const ProfileID, VarData, ProfileIDHash>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
union VarDataValue {
bool m_FundData;
/// m_ObjData, m_ArrData are stored as pointers for VarDataValue to take
/// m_ArrData is stored as pointers for VarDataValue to take
/// less space.
ObjMap* m_ObjData;
/// Both arrays and and objects are modelled using m_ArrData;
ArrMap* m_ArrData;
Expr* m_RefData;
VarDataValue() : m_FundData(false) {}
Expand All @@ -117,12 +74,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
VarData(const QualType QT);

/// Erases all children VarData's of this VarData.
void erase() {
if (type == OBJ_TYPE) {
for (auto& pair : *val.m_ObjData)
pair.second.erase();
delete val.m_ObjData;
} else if (type == ARR_TYPE) {
void erase() const {
if (type == OBJ_TYPE || type == ARR_TYPE) {
for (auto& pair : *val.m_ArrData)
pair.second.erase();
delete val.m_ArrData;
Expand All @@ -138,7 +91,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// indices/members of the expression being overlaid and the index of of the
/// current index/member.
void overlay(VarData& targetData,
llvm::SmallVector<IdxOrMember, 2>& IdxAndMemberSequence,
llvm::SmallVector<ProfileID, 2>& IDSequence,
size_t i);
/// Returns true if there is at least one required to store node among
/// child nodes.
Expand Down Expand Up @@ -190,8 +143,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
std::unordered_map<const clang::VarDecl*, VarData>();
VarsData* prev = nullptr;

VarsData() {}
VarsData(VarsData& other) : data(other.data), prev(other.prev) {}
VarsData() = default;

~VarsData() {
for (auto& pair : data)
Expand Down Expand Up @@ -264,7 +216,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
std::vector<short> blockPassCounter;

/// ID of the CFG block being visited.
unsigned curBlockID;
unsigned curBlockID{};

/// The set of IDs of the CFG blocks that should be visited.
std::set<unsigned> CFGQueue;
Expand All @@ -290,7 +242,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

//// Modes Setters
/// Sets the mode manually
void setMode(int mode) { modeStack.push_back(mode); }
void setMode(short mode) { modeStack.push_back(mode); }
/// Sets nonLinearMode but leaves markingMode just as it was.
void startNonLinearMode() {
modeStack.push_back(modeStack.back() | Mode::nonLinearMode);
Expand All @@ -310,7 +262,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {

/// Destructor
~TBRAnalyzer() {
for (auto varsData : blockData) {
for (auto* varsData : blockData) {
if (varsData) {
delete varsData;
}
Expand Down
84 changes: 35 additions & 49 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ namespace clad {
void TBRAnalyzer::setIsRequired(VarData& varData, bool isReq) {
if (varData.type == VarData::FUND_TYPE)
varData.val.m_FundData = isReq;
else if (varData.type == VarData::OBJ_TYPE)
for (auto& pair : *varData.val.m_ObjData)
setIsRequired(pair.second, isReq);
else if (varData.type == VarData::ARR_TYPE)
else if (varData.type == VarData::OBJ_TYPE || varData.type == VarData::ARR_TYPE)
for (auto& pair : *varData.val.m_ArrData)
setIsRequired(pair.second, isReq);
else if (varData.type == VarData::REF_TYPE && varData.val.m_RefData)
Expand All @@ -24,8 +21,8 @@ void TBRAnalyzer::merge(VarData& targetData, VarData& mergeData) {
targetData.val.m_FundData =
targetData.val.m_FundData || mergeData.val.m_FundData;
} else if (targetData.type == VarData::OBJ_TYPE) {
for (auto& pair : *targetData.val.m_ObjData)
merge(pair.second, (*mergeData.val.m_ObjData)[pair.first]);
for (auto& pair : *targetData.val.m_ArrData)
merge(pair.second, (*mergeData.val.m_ArrData)[pair.first]);

Check warning on line 25 in lib/Differentiator/TBRAnalyzer.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/TBRAnalyzer.cpp#L24-L25

Added lines #L24 - L25 were not covered by tests
} else if (targetData.type == VarData::ARR_TYPE) {
/// FIXME: Currently non-constant indices are not supported in merging.
for (auto& pair : *targetData.val.m_ArrData) {
Expand All @@ -49,11 +46,7 @@ TBRAnalyzer::VarData TBRAnalyzer::copy(VarData& copyData) {
res.type = copyData.type;
if (copyData.type == VarData::FUND_TYPE) {
res.val.m_FundData = copyData.val.m_FundData;
} else if (copyData.type == VarData::OBJ_TYPE) {
res.val.m_ObjData = new ObjMap();
for (auto& pair : *copyData.val.m_ObjData)
(*res.val.m_ObjData)[pair.first] = copy(pair.second);
} else if (copyData.type == VarData::ARR_TYPE) {
} else if (copyData.type == VarData::OBJ_TYPE || copyData.type == VarData::ARR_TYPE) {
res.val.m_ArrData = new ArrMap();
for (auto& pair : *copyData.val.m_ArrData)
(*res.val.m_ArrData)[pair.first] = copy(pair.second);
Expand All @@ -66,11 +59,7 @@ TBRAnalyzer::VarData TBRAnalyzer::copy(VarData& copyData) {
bool TBRAnalyzer::findReq(const VarData& varData) {
if (varData.type == VarData::FUND_TYPE)
return varData.val.m_FundData;
if (varData.type == VarData::OBJ_TYPE) {
for (auto& pair : *varData.val.m_ObjData)
if (findReq(pair.second))
return true;
} else if (varData.type == VarData::ARR_TYPE) {
if (varData.type == VarData::OBJ_TYPE || varData.type == VarData::ARR_TYPE) {
for (auto& pair : *varData.val.m_ArrData)
if (findReq(pair.second))
return true;
Expand All @@ -86,23 +75,20 @@ bool TBRAnalyzer::findReq(const VarData& varData) {

void TBRAnalyzer::overlay(
VarData& targetData,
llvm::SmallVector<IdxOrMember, 2>& IdxAndMemberSequence, size_t i) {
llvm::SmallVector<ProfileID, 2>& IDSequence, size_t i) {
if (i == 0) {
setIsRequired(targetData);
return;
}
--i;
IdxOrMember& curIdxOrMember = IdxAndMemberSequence[i];
if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::FIELD) {
overlay((*targetData.val.m_ObjData)[curIdxOrMember.val.field],
IdxAndMemberSequence, i);
} else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) {
auto idx = curIdxOrMember.val.index;
if (eqAPInt(idx, llvm::APInt(2, -1, true)))
for (auto& pair : *targetData.val.m_ArrData)
overlay(pair.second, IdxAndMemberSequence, i);
else
overlay((*targetData.val.m_ArrData)[idx], IdxAndMemberSequence, i);
ProfileID& curID = IDSequence[i];
// non-constant indices are represented with default ID.
ProfileID nonConstIdxID;
if (curID == nonConstIdxID) {
for (auto& pair : *targetData.val.m_ArrData)
overlay(pair.second, IDSequence, i);
} else {
overlay((*targetData.val.m_ArrData)[curID], IDSequence, i);

Check warning on line 91 in lib/Differentiator/TBRAnalyzer.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/TBRAnalyzer.cpp#L91

Added line #L91 was not covered by tests
}
}

Expand All @@ -120,7 +106,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME,
if (nonConstIndexFound && !addNonConstIdx)
return baseData;

return &(*baseData->val.m_ObjData)[FD];
return &(*baseData->val.m_ArrData)[getProfileID(FD)];
}
return nullptr;

Check warning on line 111 in lib/Differentiator/TBRAnalyzer.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/TBRAnalyzer.cpp#L111

Added line #L111 was not covered by tests
}
Expand All @@ -129,13 +115,12 @@ TBRAnalyzer::VarData*
TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE,
bool addNonConstIdx) {
const auto* idxExpr = ASE->getIdx();
llvm::APInt idx;
ProfileID idxID;
if (const auto* IL = dyn_cast<IntegerLiteral>(idxExpr)) {
idx = IL->getValue();
idxID = getProfileID(IL);
} else {
nonConstIndexFound = true;
/// Non-const indices are represented with -1.
idx = llvm::APInt(2, -1, true);
/// Non-const indices are represented with default FoldingSetNodeID.
}

const auto* base = ASE->getBase()->IgnoreImpCasts();
Expand All @@ -150,15 +135,16 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE,
return baseData;

auto* baseArrMap = baseData->val.m_ArrData;
auto it = baseArrMap->find(idx);
auto it = baseArrMap->find(idxID);

/// Add the current index if it was not added previously
if (it == baseArrMap->end()) {
auto& idxData = (*baseArrMap)[idx];
/// Since -1 represents non-const indices, whenever we add a new index we
/// have to copy the VarData of -1's element (if an element with undefined
/// index was used this might be our current element).
idxData = copy((*baseArrMap)[llvm::APInt(2, -1, true)]);
auto& idxData = (*baseArrMap)[idxID];
/// Since default ID represents non-const indices, whenever we add a new
/// index we have to copy the VarData of default ID's element (if an element
/// with undefined index was used this might be our current element).
ProfileID nonConstIdxID;
idxData = copy((*baseArrMap)[nonConstIdxID]);
return &idxData;
}

Expand Down Expand Up @@ -209,40 +195,41 @@ TBRAnalyzer::VarData::VarData(const QualType QT) {
elemType = pointerType->getPointeeType().getTypePtrOrNull();
else
elemType = QT->getArrayElementTypeNoTypeQual();
auto& idxData = (*val.m_ArrData)[llvm::APInt(2, -1, true)];
ProfileID nonConstIdxID;
auto& idxData = (*val.m_ArrData)[nonConstIdxID];
idxData = VarData (QualType::getFromOpaquePtr(elemType));
} else if (QT->isBuiltinType()) {
type = VarData::FUND_TYPE;
val.m_FundData = false;
} else if (QT->isRecordType()) {
type = VarData::OBJ_TYPE;
const auto* recordDecl = QT->getAs<RecordType>()->getDecl();
auto& newObjMap = val.m_ObjData;
newObjMap = new ObjMap();
auto& newArrMap = val.m_ArrData;
newArrMap = new ArrMap();
for (const auto* field : recordDecl->fields()) {
const auto varType = field->getType();
(*newObjMap)[field] = VarData(varType);
(*newArrMap)[getProfileID(field)] = VarData(varType);
}
}
}

void TBRAnalyzer::overlay(const clang::Expr* E) {
nonConstIndexFound = false;
llvm::SmallVector<IdxOrMember, 2> IdxAndMemberSequence;
llvm::SmallVector<ProfileID, 2> IDSequence;
const clang::DeclRefExpr* innermostDRE;
bool cond = true;
/// Unwrap the given expression to a vector of indices and fields.
while (cond) {
E = E->IgnoreImplicit();
if (const auto* ASE = dyn_cast<clang::ArraySubscriptExpr>(E)) {
if (const auto* IL = dyn_cast<clang::IntegerLiteral>(ASE->getIdx()))
IdxAndMemberSequence.push_back(IdxOrMember(IL->getValue()));
IDSequence.push_back(getProfileID(IL));

Check warning on line 226 in lib/Differentiator/TBRAnalyzer.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/TBRAnalyzer.cpp#L226

Added line #L226 was not covered by tests
else
IdxAndMemberSequence.push_back(IdxOrMember(llvm::APInt(2, -1, true)));
IDSequence.push_back(ProfileID());
E = ASE->getBase();
} else if (const auto* ME = dyn_cast<clang::MemberExpr>(E)) {
if (const auto* FD = dyn_cast<clang::FieldDecl>(ME->getMemberDecl()))
IdxAndMemberSequence.push_back(IdxOrMember(FD));
IDSequence.push_back(getProfileID(FD));
E = ME->getBase();

Check warning on line 233 in lib/Differentiator/TBRAnalyzer.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/TBRAnalyzer.cpp#L231-L233

Added lines #L231 - L233 were not covered by tests
} else if (isa<clang::DeclRefExpr>(E)) {
innermostDRE = dyn_cast<clang::DeclRefExpr>(E);
Expand All @@ -253,8 +240,7 @@ void TBRAnalyzer::overlay(const clang::Expr* E) {

/// Overlay on all the VarData's recursively.
if (const auto* VD = dyn_cast<clang::VarDecl>(innermostDRE->getDecl())) {
overlay(getCurBranch()[VD], IdxAndMemberSequence,
IdxAndMemberSequence.size());
overlay(getCurBranch()[VD], IDSequence, IDSequence.size());
}
}

Expand Down

0 comments on commit c9ba780

Please sign in to comment.