Skip to content

Commit 9058040

Browse files
Elias Ellisonfacebook-github-bot
Elias Ellison
authored andcommitted
Add more list peephole idioms (pytorch#48268)
Summary: Pull Request resolved: pytorch#48268 Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D25104617 Pulled By: eellison fbshipit-source-id: b41c03d5da6e9b88acf21a859f61c5c70608c150
1 parent 39d3578 commit 9058040

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

torch/csrc/jit/passes/peephole_list_idioms.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,42 @@ struct PeepholeOptimizeListIdiomsImpl {
8989
}
9090
}
9191
}
92+
} else if (node->kind() == prim::ListUnpack) {
93+
auto list_creation_node = first_input->node();
94+
if (list_creation_node->kind() == prim::ListConstruct) {
95+
// if sizes are unequal it's a runtime error
96+
if (list_creation_node->inputs().size() != node->outputs().size()) {
97+
continue;
98+
}
99+
for (size_t i = 0; i < node->outputs().size(); ++i) {
100+
node->output(i)->replaceAllUsesWith(list_creation_node->inputs().at(i));
101+
}
102+
}
103+
} else if (node->kind() == aten::add) {
104+
if (node->inputs().size() != 2) {
105+
continue;
106+
}
107+
auto second_input = node->inputs().at(1);
108+
// already checked first, need to check second
109+
if (mutated_lists_.count(second_input)) {
110+
continue;
111+
}
112+
if (first_input->node()->kind() != prim::ListConstruct || second_input->node()->kind() != prim::ListConstruct) {
113+
continue;
114+
}
115+
WithInsertPoint guard(node);
116+
auto list_construct = graph_->insertNode(graph_->create(prim::ListConstruct));
117+
list_construct->output()->setType(node->output()->type());
118+
for (Value * v: first_input->node()->inputs()) {
119+
list_construct->addInput(v);
120+
}
121+
for (Value * v: second_input->node()->inputs()) {
122+
list_construct->addInput(v);
123+
}
124+
node->output()->replaceAllUsesWith(list_construct->output());
125+
if (mutated_lists_.count(node->output())) {
126+
mutated_lists_.insert(list_construct->output());
127+
}
92128
}
93129
}
94130
}

0 commit comments

Comments
 (0)