From bb87743965123e7b57e3437d255d10467e90416c Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Sun, 22 Dec 2024 16:47:43 +0000 Subject: [PATCH] Improve DiffRequest and DynamicGraph printing. --- include/clad/Differentiator/DiffPlanner.h | 19 +++++++---------- include/clad/Differentiator/Differentiator.h | 1 - include/clad/Differentiator/DynamicGraph.h | 22 +++++++++++++------- lib/Differentiator/DiffPlanner.cpp | 15 +++++++++++++ tools/ClangPlugin.cpp | 2 +- unittests/Misc/DynamicGraph.cpp | 15 ++++++++----- 6 files changed, 48 insertions(+), 26 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 30b483b7e..4c4f213f6 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -1,14 +1,18 @@ #ifndef CLAD_DIFF_PLANNER_H #define CLAD_DIFF_PLANNER_H -#include "clang/AST/RecursiveASTVisitor.h" -#include "llvm/ADT/SmallSet.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clang/AST/RecursiveASTVisitor.h" + +#include "llvm/ADT/SmallSet.h" +#include + #include #include + namespace clang { class CallExpr; class CompilerInstance; @@ -132,15 +136,8 @@ struct DiffRequest { const clang::FunctionDecl* operator->() const { return Function; } - // String operator for printing the node. - operator std::string() const { - std::string res = BaseFunctionName + "__order_" + - std::to_string(CurrentDerivativeOrder) + "__mode_" + - DiffModeToString(Mode); - if (EnableTBRAnalysis) - res += "__TBR"; - return res; - } + void print(llvm::raw_ostream& Out) const; + void dump() const { print(llvm::errs()); } bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index c8aaaa286..60e4b014c 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -14,7 +14,6 @@ #include "BuiltinDerivativesCUDA.cuh" #endif #include "CladConfig.h" -#include "DynamicGraph.h" #include "FunctionTraits.h" #include "Matrix.h" #include "NumericalDiff.h" diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index 2ef8cf992..50b799cb7 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -1,6 +1,8 @@ #ifndef CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H #define CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H +#include "llvm/Support/raw_ostream.h" + #include #include #include @@ -109,23 +111,27 @@ template class DynamicGraph { const std::vector& getNodes() const { return m_nodes; } std::vector& getNodes() { return m_nodes; } + /// Dump the nodes and edges. + void dump() const { print(llvm::errs()); } + /// Print the nodes and edges in the graph. - void print() { + void print(llvm::raw_ostream& Out) const { // First print the nodes with their insertion order. for (const T& node : m_nodes) { - std::pair nodeInfo = m_nodeMap[node]; - std::cout << (std::string)node << ": #" << nodeInfo.second; + std::pair nodeInfo = m_nodeMap.at(node); + node.print(Out); + Out << ": #" << nodeInfo.second; if (m_sources.find(nodeInfo.second) != m_sources.end()) - std::cout << " (source)"; + Out << " (source)"; if (nodeInfo.first) - std::cout << ", (done)\n"; + Out << ", (done)\n"; else - std::cout << ", (unprocessed)\n"; + Out << ", (unprocessed)\n"; } // Then print the edges. for (int i = 0; i < m_nodes.size(); i++) - for (size_t dest : m_adjList[i]) - std::cout << i << " -> " << dest << "\n"; + for (size_t dest : m_adjList.at(i)) + Out << i << " -> " << dest << "\n"; } /// Get the next node to be processed from the queue of nodes to be diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d2c39c1d9..cd90e0c01 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -1,5 +1,7 @@ #include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/DiffMode.h" + #include "ActivityAnalyzer.h" #include "TBRAnalyzer.h" @@ -601,6 +603,19 @@ namespace clad { return; } + void DiffRequest::print(llvm::raw_ostream& Out) const { + Out << '<'; + PrintingPolicy Policy(Function->getASTContext().getLangOpts()); + Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true); + Out << ">[name=" << BaseFunctionName << ", " + << "order=" << CurrentDerivativeOrder << ", " + << "mode=" << DiffModeToString(Mode); + if (EnableTBRAnalysis) + Out << ", tbr"; + Out << ']'; + Out.flush(); + } + bool DiffRequest::shouldBeRecorded(Expr* E) const { if (!EnableTBRAnalysis) return true; diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 9e1977e0b..d446d9d49 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -526,7 +526,7 @@ namespace clad { // Print the graph of the diff requests. llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n"; - m_DiffRequestGraph.print(); + m_DiffRequestGraph.print(llvm::errs()); m_Multiplexer->PrintStats(); } diff --git a/unittests/Misc/DynamicGraph.cpp b/unittests/Misc/DynamicGraph.cpp index 6954a6698..b925b86c9 100644 --- a/unittests/Misc/DynamicGraph.cpp +++ b/unittests/Misc/DynamicGraph.cpp @@ -1,4 +1,7 @@ #include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/DynamicGraph.h" + +#include "llvm/Support/raw_ostream.h" #include #include @@ -16,7 +19,9 @@ struct Node { } // String operator for printing the node. - operator std::string() const { return name + std::to_string(id); } + void print(llvm::raw_ostream &Out) const { + Out << name << std::to_string(id); + } }; // Specialize std::hash for the Node type. @@ -44,10 +49,10 @@ TEST(DynamicGraphTest, Printing) { // Check the printed output. std::stringstream ss; - std::streambuf* coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - G.print(); - std::cout.rdbuf(coutbuf); + std::streambuf* coutbuf = std::cerr.rdbuf(); + std::cerr.rdbuf(ss.rdbuf()); + G.dump(); + std::cerr.rdbuf(coutbuf); std::string expectedOutput = "node0: #0 (source), (unprocessed)\n" "node1: #1, (unprocessed)\n" "node2: #2, (unprocessed)\n"