5
5
#include < torch/csrc/jit/ir/node_hashing.h>
6
6
#include < torch/csrc/jit/jit_log.h>
7
7
8
- #include < c10/util/hash.h>
9
8
#include < unordered_map>
10
9
11
10
namespace torch {
12
11
namespace jit {
13
12
namespace {
14
13
15
- // There are a context managers which change global state -
16
- // with torch.no_grad(), with torch.cpu.amp.autocast
17
- // These are represented in JIT as prim::Enter and prim::Exit nodes
18
- // Avoid CSE across two separate with statements
19
-
20
- struct NodeAndContextNode : public std ::pair<Node*, Node*> {
21
- using pair::pair;
22
-
23
- Node* node () const {
24
- return this ->first ;
25
- }
26
- Node* contextNode () const {
27
- return this ->second ;
28
- }
29
- };
30
-
31
- struct TORCH_API HashNodeAndContext {
32
- size_t operator ()(const NodeAndContextNode pair) const {
33
- HashNode hash;
34
- // we hash on the properties of the Node to be CSE'd, and the exact node
35
- // of the context its in (which may be a nullptr)
36
- return c10::hash_combine (
37
- hash (pair.node ()), reinterpret_cast <size_t >(pair.contextNode ()));
38
- }
39
- };
40
-
41
- struct TORCH_API EqualNodeAndContext {
42
- bool operator ()(const NodeAndContextNode lhs, const NodeAndContextNode rhs)
43
- const {
44
- EqualNode eq;
45
- bool nodes_equal = eq (lhs.node (), rhs.node ());
46
- // similarly to equality, check equality of properties of the nodes to be
47
- // CSE'd
48
- // and the exact node of its context (which may be nullptr)
49
- return nodes_equal && (lhs.contextNode () == rhs.contextNode ());
50
- }
51
- };
52
-
53
14
struct CommonSubexpressionEliminator {
54
15
CommonSubexpressionEliminator (std::shared_ptr<Graph> graph)
55
16
: graph_(std::move(graph)) {}
@@ -62,22 +23,11 @@ struct CommonSubexpressionEliminator {
62
23
// Since the nodes are visited in topological order, one pass is enough.
63
24
// returns true if CSE made changes to a graph
64
25
bool run (Block* block, std::function<Node*(Node*)> parent_lookup_fn) {
65
- std::unordered_set<
66
- NodeAndContextNode,
67
- HashNodeAndContext,
68
- EqualNodeAndContext>
69
- subexprs;
26
+ std::unordered_set<Node*, HashNode, EqualNode> subexprs;
70
27
bool changed = false ;
71
28
for (auto it = block->nodes ().begin (); it != block->nodes ().end (); ++it) {
72
29
auto node = *it;
73
30
74
- if (node->kind () == prim::Enter) {
75
- prim_enter_stack_.push_back (node);
76
- }
77
- if (node->kind () == prim::Exit) {
78
- prim_enter_stack_.pop_back ();
79
- }
80
-
81
31
if (node->kind () == prim::profile) {
82
32
GRAPH_DEBUG (
83
33
" Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n " ,
@@ -98,13 +48,9 @@ struct CommonSubexpressionEliminator {
98
48
// Traverse sub-blocks.
99
49
for (auto block : node->blocks ()) {
100
50
changed |= run (block, [&](Node* n) {
101
- NodeAndContextNode nacn (
102
- n,
103
- prim_enter_stack_.size () == 0 ? nullptr
104
- : prim_enter_stack_.back ());
105
- auto existing = subexprs.find (nacn);
51
+ auto existing = subexprs.find (n);
106
52
if (existing != subexprs.end ()) {
107
- return ( *existing). node () ;
53
+ return *existing;
108
54
}
109
55
110
56
return parent_lookup_fn (n);
@@ -137,12 +83,10 @@ struct CommonSubexpressionEliminator {
137
83
}
138
84
139
85
// Check whether the same subexpression already exists.
140
- Node* enter_node =
141
- prim_enter_stack_.size () == 0 ? nullptr : prim_enter_stack_.back ();
142
- auto subit = subexprs.insert (NodeAndContextNode (node, enter_node));
86
+ auto subit = subexprs.insert (node);
143
87
if (!subit.second ) {
144
88
// Subexpression exists, replace the uses of node, and destroy it.
145
- auto existing = ( *subit.first ). node () ;
89
+ auto existing = *subit.first ;
146
90
147
91
// don't introduce new aliasing among graph outputs
148
92
if (getOrCreateAliasDb ().mayContainAlias (
@@ -171,7 +115,6 @@ struct CommonSubexpressionEliminator {
171
115
}
172
116
173
117
private:
174
- std::vector<Node*> prim_enter_stack_;
175
118
std::unique_ptr<AliasDb> alias_db_;
176
119
std::shared_ptr<Graph> graph_;
177
120
};
0 commit comments