forked from onnx/onnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path slice_9_10.h
49 lines (40 loc) · 1.33 KB
/
slice_9_10.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Slice in default domain from version 9 to 10
#pragma once
namespace ONNX_NAMESPACE {
namespace version_conversion {

/// Converts a default-domain `Slice` node from opset 9 to opset 10.
///
/// The opset-9 node carries `starts`, `ends` and (optionally) `axes` as
/// integer-list attributes. This adapter removes each of those attributes and
/// instead feeds its values to the node as an input produced by a freshly
/// created Constant node.
class Slice_9_10 final : public Adapter {
 public:
  explicit Slice_9_10() : Adapter("Slice", OpSetID(9), OpSetID(10)) {}

  /// Materializes `attr` as a 1-D INT64 Constant placed immediately before
  /// `node`, then appends the constant's output to `node`'s input list.
  ///
  /// @param graph Graph in which the Constant node is created (presumably the
  ///              graph that contains `node`). Taken by const reference: the
  ///              function never retains it, so copying the shared_ptr would
  ///              only cost an atomic refcount round-trip.
  /// @param node  Node that receives the new input. Inputs are appended, so
  ///              the order of successive calls fixes the input positions.
  /// @param attr  Values stored in the constant tensor. An empty vector yields
  ///              a tensor of shape {0}.
  void attrToInput(const std::shared_ptr<Graph>& graph, Node* node, const std::vector<int64_t>& attr) const {
    Tensor t;
    t.elem_type() = TensorProto_DataType_INT64;
    t.sizes() = std::vector<int64_t>{static_cast<int64_t>(attr.size())};
    // `t` is freshly constructed, so this bulk-copies into an empty buffer.
    auto& data = t.int64s();
    data.insert(data.end(), attr.begin(), attr.end());

    Node* constant = graph->create(kConstant);
    // Insert the producer ahead of its consumer in the node list.
    constant->insertBefore(node);
    constant->t_(kvalue, t);
    node->addInput(constant->output());
  }

  /// Rewrites `node` in place: moves `starts`, then `ends`, then `axes` (only
  /// when that attribute is present) from attributes to appended inputs.
  ///
  /// The call order determines input positions, because attrToInput appends.
  /// It presumably mirrors the positional inputs of Slice-10
  /// (data, starts, ends, axes); see the ONNX Slice operator spec.
  /// `starts` and `ends` are read unconditionally. Only `axes` is guarded
  /// by hasAttribute.
  void adapt_slice_9_10(const std::shared_ptr<Graph>& graph, Node* node) const {
    // Each attribute's values are copied into the constant tensor before the
    // attribute itself is removed.
    attrToInput(graph, node, node->is(kstarts));
    node->removeAttribute(kstarts);
    attrToInput(graph, node, node->is(kends));
    node->removeAttribute(kends);
    if (node->hasAttribute(kaxes)) {
      attrToInput(graph, node, node->is(kaxes));
      node->removeAttribute(kaxes);
    }
  }

  /// Adapter entry point: converts `node` in place and returns the same node.
  /// The signature is fixed by the overridden Adapter::adapt.
  Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
    adapt_slice_9_10(graph, node);
    return node;
  }
};

} // namespace version_conversion
} // namespace ONNX_NAMESPACE