@@ -89,6 +89,42 @@ struct PeepholeOptimizeListIdiomsImpl {
89
89
}
90
90
}
91
91
}
92
+ } else if (node->kind () == prim::ListUnpack) {
93
+ auto list_creation_node = first_input->node ();
94
+ if (list_creation_node->kind () == prim::ListConstruct) {
95
+ // if sizes are unequal it's a runtime error
96
+ if (list_creation_node->inputs ().size () != node->outputs ().size ()) {
97
+ continue ;
98
+ }
99
+ for (size_t i = 0 ; i < node->outputs ().size (); ++i) {
100
+ node->output (i)->replaceAllUsesWith (list_creation_node->inputs ().at (i));
101
+ }
102
+ }
103
+ } else if (node->kind () == aten::add) {
104
+ if (node->inputs ().size () != 2 ) {
105
+ continue ;
106
+ }
107
+ auto second_input = node->inputs ().at (1 );
108
+ // already checked first, need to check second
109
+ if (mutated_lists_.count (second_input)) {
110
+ continue ;
111
+ }
112
+ if (first_input->node ()->kind () != prim::ListConstruct || second_input->node ()->kind () != prim::ListConstruct) {
113
+ continue ;
114
+ }
115
+ WithInsertPoint guard (node);
116
+ auto list_construct = graph_->insertNode (graph_->create (prim::ListConstruct));
117
+ list_construct->output ()->setType (node->output ()->type ());
118
+ for (Value * v: first_input->node ()->inputs ()) {
119
+ list_construct->addInput (v);
120
+ }
121
+ for (Value * v: second_input->node ()->inputs ()) {
122
+ list_construct->addInput (v);
123
+ }
124
+ node->output ()->replaceAllUsesWith (list_construct->output ());
125
+ if (mutated_lists_.count (node->output ())) {
126
+ mutated_lists_.insert (list_construct->output ());
127
+ }
92
128
}
93
129
}
94
130
}
0 commit comments