Skip to content

Commit

Permalink
Make the tbr header local
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Nov 23, 2023
1 parent d48dbae commit d1bab8d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
2 changes: 1 addition & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

#include "ConstantFolder.h"

#include "TBRAnalyzer.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/ExternalRMVSource.h"
#include "clad/Differentiator/MultiplexExternalRMVSource.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/TBRAnalyzer.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "clad/Differentiator/TBRAnalyzer.h"
#include "TBRAnalyzer.h"

using namespace clang;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
/// and object fields.
using ProfileID = clad_compat::FoldingSetNodeID;

ProfileID getProfileID(const Expr* E) const{
ProfileID getProfileID(const Expr* E) const {
ProfileID profID;
E->Profile(profID, m_Context, /* Canonical */ true);
return profID;
Expand All @@ -40,22 +40,19 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
}

struct ProfileIDHash {
size_t operator()(const ProfileID& x) const {
return x.ComputeHash();
}
size_t operator()(const ProfileID& x) const { return x.ComputeHash(); }
};


struct VarData;
using ArrMap =
std::unordered_map<const ProfileID, VarData, ProfileIDHash>;
using ArrMap = std::unordered_map<const ProfileID, VarData, ProfileIDHash>;

// NOLINTBEGIN(cppcoreguidelines-pro-type-union-access)

/// Stores all the necessary information about one variable. Fundamental type
/// variables need only one bit. An object/array needs a separate VarData for
/// each field/element. Reference type variables store the clang::Expr* they
/// refer to. UNDEFINED is used whenever the type of a node cannot be determined.
/// refer to. UNDEFINED is used whenever the type of a node cannot be
/// determined.
///
/// FIXME: Pointers to objects are considered OBJ_TYPE for simplicity. This
/// approach might cause problems when the support for pointers is added.
Expand All @@ -82,7 +79,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {

VarData() = default;
VarData(const VarData& other) = delete;
VarData(VarData&& other) noexcept: type(other.type) {
VarData(VarData&& other) noexcept : type(other.type) {
*this = std::move(other);
}
VarData& operator=(const VarData& other) = delete;
Expand All @@ -105,9 +102,8 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {

/// Erases all children VarData's of this VarData.
~VarData() {
if (type == OBJ_TYPE || type == ARR_TYPE) {
if (type == OBJ_TYPE || type == ARR_TYPE)
val.m_ArrData.reset();
}
}
};
// NOLINTEND(cppcoreguidelines-pro-type-union-access)
Expand All @@ -120,8 +116,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
/// when 'a[k].y' is set to required). Takes unwrapped sequence of
/// indices/members of the expression being overlaid and the index of of the
/// current index/member.
void overlay(VarData& targetData,
llvm::SmallVector<ProfileID, 2>& IDSequence,
void overlay(VarData& targetData, llvm::SmallVector<ProfileID, 2>& IDSequence,
size_t i);
/// Returns true if there is at least one required to store node among
/// child nodes.
Expand Down Expand Up @@ -173,7 +168,8 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
VarsData() = default;
VarsData(const VarsData& other) = default;
~VarsData() = default;
VarsData(VarsData&& other) noexcept : data(std::move(other.data)), prev(other.prev) {}
VarsData(VarsData&& other) noexcept
: data(std::move(other.data)), prev(other.prev) {}
VarsData& operator=(const VarsData& other) = delete;
VarsData& operator=(VarsData&& other) noexcept {
if (&data == &other.data) {
Expand All @@ -189,24 +185,24 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
iterator end() { return data.end(); }
VarData& operator[](const clang::VarDecl* VD) { return data[VD]; }
iterator find(const clang::VarDecl* VD) { return data.find(VD); }
void clear() {
data.clear();
}
void clear() { data.clear(); }
};


/// Collects the data from 'varsData' and its predecessors until
/// 'limit' into one map ('limit' VarsData is not included).
/// If 'limit' is 'nullptr', data is collected starting with
/// the entry CFG block.
/// Note: the returned VarsData contains original data from
/// the predecessors (NOT copies). It should not be modified.
std::unordered_map<const clang::VarDecl*, VarData*>
static collectDataFromPredecessors(VarsData* varsData, VarsData* limit = nullptr);
std::unordered_map<
const clang::VarDecl*,
VarData*> static collectDataFromPredecessors(VarsData* varsData,
VarsData* limit = nullptr);

/// Finds the lowest common ancestor of two VarsData
/// (based on the prev field in VarsData).
static VarsData* findLowestCommonAncestor(VarsData* varsData1, VarsData* varsData2);
static VarsData* findLowestCommonAncestor(VarsData* varsData1,
VarsData* varsData2);

/// Merges mergeData into targetData. Should be called
/// after mergeData is passed and the corresponding CFG
Expand Down

0 comments on commit d1bab8d

Please sign in to comment.