Skip to content

Commit a71faba

Browse files
Revert "Dnt CSE across context managers"
This reverts commit 0981b01. Reverted pytorch#76075 on behalf of https://github.com/seemethere
1 parent 1324410 commit a71faba

File tree

2 files changed

+5
-75
lines changed

2 files changed

+5
-75
lines changed

Diff for: test/test_jit.py

-13
Original file line numberDiff line numberDiff line change
@@ -1232,19 +1232,6 @@ def fn(x, y):
12321232

12331233
self.assertExportImport(g, (x, y))
12341234

1235-
def test_cse_context_managers(self):
1236-
def bar(x):
1237-
with torch.no_grad():
1238-
y = x * 2
1239-
1240-
z = 2 * x + y
1241-
return z, x.requires_grad, y.requires_grad, z.requires_grad
1242-
1243-
a = torch.rand(3, requires_grad=True)
1244-
b = torch.rand(3, requires_grad=True)
1245-
1246-
self.checkScript(bar, (a,))
1247-
12481235
def test_cse_not_introduce_aliasing(self):
12491236
@torch.jit.script
12501237
def tensor_alias_outputs(x):

Diff for: torch/csrc/jit/passes/common_subexpression_elimination.cpp

+5-62
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,12 @@
55
#include <torch/csrc/jit/ir/node_hashing.h>
66
#include <torch/csrc/jit/jit_log.h>
77

8-
#include <c10/util/hash.h>
98
#include <unordered_map>
109

1110
namespace torch {
1211
namespace jit {
1312
namespace {
1413

15-
// There are context managers which change global state -
16-
// with torch.no_grad(), with torch.cpu.amp.autocast
17-
// These are represented in JIT as prim::Enter and prim::Exit nodes
18-
// Avoid CSE across two separate with statements
19-
20-
struct NodeAndContextNode : public std::pair<Node*, Node*> {
21-
using pair::pair;
22-
23-
Node* node() const {
24-
return this->first;
25-
}
26-
Node* contextNode() const {
27-
return this->second;
28-
}
29-
};
30-
31-
struct TORCH_API HashNodeAndContext {
32-
size_t operator()(const NodeAndContextNode pair) const {
33-
HashNode hash;
34-
// we hash on the properties of the Node to be CSE'd, and the exact node
35-
// of the context it's in (which may be nullptr)
36-
return c10::hash_combine(
37-
hash(pair.node()), reinterpret_cast<size_t>(pair.contextNode()));
38-
}
39-
};
40-
41-
struct TORCH_API EqualNodeAndContext {
42-
bool operator()(const NodeAndContextNode lhs, const NodeAndContextNode rhs)
43-
const {
44-
EqualNode eq;
45-
bool nodes_equal = eq(lhs.node(), rhs.node());
46-
// similarly to hashing, check equality of properties of the nodes to be
47-
// CSE'd
48-
// and the exact node of its context (which may be nullptr)
49-
return nodes_equal && (lhs.contextNode() == rhs.contextNode());
50-
}
51-
};
52-
5314
struct CommonSubexpressionEliminator {
5415
CommonSubexpressionEliminator(std::shared_ptr<Graph> graph)
5516
: graph_(std::move(graph)) {}
@@ -62,22 +23,11 @@ struct CommonSubexpressionEliminator {
6223
// Since the nodes are visited in topological order, one pass is enough.
6324
// returns true if CSE made changes to a graph
6425
bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn) {
65-
std::unordered_set<
66-
NodeAndContextNode,
67-
HashNodeAndContext,
68-
EqualNodeAndContext>
69-
subexprs;
26+
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
7027
bool changed = false;
7128
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
7229
auto node = *it;
7330

74-
if (node->kind() == prim::Enter) {
75-
prim_enter_stack_.push_back(node);
76-
}
77-
if (node->kind() == prim::Exit) {
78-
prim_enter_stack_.pop_back();
79-
}
80-
8131
if (node->kind() == prim::profile) {
8232
GRAPH_DEBUG(
8333
"Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n",
@@ -98,13 +48,9 @@ struct CommonSubexpressionEliminator {
9848
// Traverse sub-blocks.
9949
for (auto block : node->blocks()) {
10050
changed |= run(block, [&](Node* n) {
101-
NodeAndContextNode nacn(
102-
n,
103-
prim_enter_stack_.size() == 0 ? nullptr
104-
: prim_enter_stack_.back());
105-
auto existing = subexprs.find(nacn);
51+
auto existing = subexprs.find(n);
10652
if (existing != subexprs.end()) {
107-
return (*existing).node();
53+
return *existing;
10854
}
10955

11056
return parent_lookup_fn(n);
@@ -137,12 +83,10 @@ struct CommonSubexpressionEliminator {
13783
}
13884

13985
// Check whether the same subexpression already exists.
140-
Node* enter_node =
141-
prim_enter_stack_.size() == 0 ? nullptr : prim_enter_stack_.back();
142-
auto subit = subexprs.insert(NodeAndContextNode(node, enter_node));
86+
auto subit = subexprs.insert(node);
14387
if (!subit.second) {
14488
// Subexpression exists, replace the uses of node, and destroy it.
145-
auto existing = (*subit.first).node();
89+
auto existing = *subit.first;
14690

14791
// don't introduce new aliasing among graph outputs
14892
if (getOrCreateAliasDb().mayContainAlias(
@@ -171,7 +115,6 @@ struct CommonSubexpressionEliminator {
171115
}
172116

173117
private:
174-
std::vector<Node*> prim_enter_stack_;
175118
std::unique_ptr<AliasDb> alias_db_;
176119
std::shared_ptr<Graph> graph_;
177120
};

0 commit comments

Comments
 (0)