forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcanonicalize_graph_fuser_ops.cpp
98 lines (92 loc) · 3.74 KB
/
canonicalize_graph_fuser_ops.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
namespace torch::jit {
// Associates a Value read out of a chunk's list output with the position in
// that list it was read from.
struct ChunkOutput {
  ChunkOutput(Value* value, size_t index) : val(value), offset(index) {}

  // The value obtained from the chunk's list output.
  Value* val;
  // Position of `val` within the list of chunks.
  size_t offset;
};
// Gathers every value that is extracted from the list produced by `chunk`,
// together with the list position it is extracted from.
//
// Each use of the list must be one of:
//   * aten::select with a constant index whose result is a Tensor, or
//   * prim::ListUnpack with exactly `chunks` outputs.
//
// Returns std::nullopt when any use has a different form, when a ListUnpack
// has an unexpected number of outputs, or when a select index cannot be used
// as an output position (negative, or not smaller than the constant `chunks`).
// Otherwise returns the collected (value, offset) pairs; the result is empty
// if the list has no uses.
static std::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
  std::vector<ChunkOutput> outputs;
  for (const auto& list_use : chunk->output()->uses()) {
    Node* user = list_use.user;
    if (user->matches("aten::select(t[] list, int idx) -> t", attr::idx) &&
        user->output()->type()->cast<TensorType>()) {
      const int64_t idx = user->get<int64_t>(attr::idx).value();
      // The offset is stored unsigned and later used to index the outputs of
      // the replacement node. A negative index would wrap to a huge value and
      // an index >= chunks would run past the end, so refuse both.
      const auto num_chunks = chunk->get<int64_t>(attr::chunks);
      if (idx < 0 || (num_chunks && idx >= *num_chunks)) {
        return std::nullopt;
      }
      outputs.emplace_back(user->output(), static_cast<size_t>(idx));
    } else if (user->kind() == prim::ListUnpack) {
      // This sometimes happens if the sizes can't be evenly divided by the
      // number of chunks
      if (static_cast<int64_t>(user->outputs().size()) !=
          chunk->get<int64_t>(attr::chunks).value()) {
        return std::nullopt;
      }
      auto unpack_outputs = user->outputs();
      for (const auto i : c10::irange(unpack_outputs.size())) {
        outputs.emplace_back(unpack_outputs[i], i);
      }
    } else {
      // Unknown consumer of the list: give up on this chunk.
      return std::nullopt;
    }
  }
  return outputs;
}
// Visits every node of `block`, descending into nested blocks first, and
// rewrites two patterns:
//   * aten::add / aten::sub / aten::mul / aten::div whose `other` input is a
//     constant rank-0 Tensor: a copy of the node is inserted in front of it
//     with `other` replaced by a scalar constant, and all uses of the old
//     output are redirected to the copy.
//   * aten::chunk with constant `chunks` and `dim`, when getChunkOutputs()
//     accepts all uses of its list output: a prim::ConstantChunk node with one
//     output per chunk is inserted in front of it and the extracted values are
//     redirected to the matching outputs.
// The nodes that were superseded are not removed here.
static void CanonicalizeOps(Block* block) {
  for (Node* n : block->nodes()) {
    for (Block* inner : n->blocks()) {
      CanonicalizeOps(inner);
    }

    const bool is_binary_arith =
        n->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        n->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
        n->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
        n->matches("aten::div(Tensor self, Tensor other) -> Tensor");

    if (is_binary_arith) {
      // Replace rank 0 Tensor constants with scalar constants.
      auto rhs = n->get<at::Tensor>(attr::other);
      if (!rhs || rhs->dim() != 0) {
        continue;
      }
      WithInsertPoint insert_guard{n};
      auto owning_graph = n->owningGraph();
      auto scalar_rhs = owning_graph->insertConstant(rhs->item());
      std::vector<Value*> new_inputs = n->inputs().vec();
      new_inputs.at(1) = scalar_rhs;
      Node* replacement =
          owning_graph->insertNode(owning_graph->create(n->kind(), new_inputs));
      Value* new_output = replacement->output();
      replacement->copyMetadata(n);
      new_output->copyMetadata(n->output());
      n->output()->replaceAllUsesWith(new_output);
      continue;
    }

    if (!n->matches(
            "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
            /*const_inputs=*/{attr::chunks, attr::dim})) {
      continue;
    }

    // Replace aten::chunk (which returns a list) with ConstantChunk with the
    // outputs unpacked.
    auto extracted = getChunkOutputs(n);
    if (!extracted) {
      continue;
    }
    WithInsertPoint guard(n);
    auto* chunk_input = n->namedInput(attr::self);
    auto* owning_graph = n->owningGraph();
    const auto num_chunks = n->get<int64_t>(attr::chunks).value();
    const auto chunk_dim = n->get<int64_t>(attr::dim).value();
    auto* const_chunk = owning_graph->insertNode(
        owning_graph->create(prim::ConstantChunk, num_chunks));
    const_chunk->addInput(chunk_input);
    const_chunk->i_(attr::chunks, num_chunks);
    const_chunk->i_(attr::dim, chunk_dim);
    const_chunk->copyMetadata(n);
    for (const auto& extracted_out : *extracted) {
      Value* old_value = extracted_out.val;
      old_value->replaceAllUsesWith(
          const_chunk->outputs()[extracted_out.offset]);
      // Carry the type of the value being replaced over to the new output.
      const_chunk->outputs()[extracted_out.offset]->setType(old_value->type());
    }
  }
}
void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
CanonicalizeOps(graph->block());
GRAPH_DUMP("After CanonicalizeOps: ", graph);
EliminateDeadCode(graph);
}
} // namespace torch::jit