This repository has been archived by the owner on Mar 1, 2023. It is now read-only.
forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path evaluator.hpp
179 lines (157 loc) · 5.93 KB
/
evaluator.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include <stack>
#include <utility>
#include "ngraph/deprecated.hpp"
#include "ngraph/node.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type_traits.hpp"
namespace ngraph {
/// \brief Execute handlers on a subgraph to compute values
///
/// Values are computed for node outputs on demand by walking the graph from a requested output
/// towards its inputs with an explicit instruction stack (no recursion). Results are cached in a
/// caller-supplied value map so repeated evaluations can reuse earlier work.
template <typename V>
class NGRAPH_DEPRECATED("This class is deprecated and will be removed soon.") Evaluator {
    NGRAPH_SUPPRESS_DEPRECATED_START
public:
    /// \brief Values computed so far, keyed by node output.
    using value_map = std::map<RawNodeOutput, V>;
    /// \brief Handler for a computation of a value about an op
    ///
    /// A handler is passed a Node* and a vector of computed input values, one per node input in
    /// input order. The handler should return a vector of computed output values, one per node
    /// output in output order. Outputs for which the handler returns no value are recorded as a
    /// default-constructed V() ("unknown"), the same value used for ops that have no handler.
    using op_handler = std::function<std::vector<V>(Node* op, std::vector<V>& inputs)>;
    /// \brief Table of ops with handlers
    using op_handler_map = std::map<Node::type_info_t, op_handler>;

    /// \brief Construct an evaluator using the provided op handlers.
    ///
    /// Evaluations share previously computed values so that calls on multiple nodes can share
    /// work. All state is kept in the value map, which is accessible for clearing or seeding
    /// with Evaluator::get_value_map().
    ///
    /// \param handlers Handlers for ops: pairs of Node::type_info_t and handler functions. The
    ///                 table is copied into the evaluator.
    /// \param values   Map that receives (and may be pre-seeded with) computed values. It is held
    ///                 by reference, so it must outlive this evaluator.
    Evaluator(const op_handler_map& handlers, value_map& values) : m_handlers(handlers), m_value_map(values) {}

    /// \brief Retrieves the value_map, which holds all Output<Node> value associations.
    value_map& get_value_map() {
        return m_value_map;
    }
    /// \brief Retrieves the value_map, which holds all Output<Node> value associations.
    const value_map& get_value_map() const {
        return m_value_map;
    }
    /// \brief If set, handles all ops
    /// \note The misspelled name is kept for backward compatibility; prefer
    ///       get_universal_handler().
    const op_handler& get_univeral_handler() const {
        return m_universal_handler;
    }
    /// \brief If set, handles all ops
    const op_handler& get_universal_handler() const {
        return m_universal_handler;
    }
    /// \brief If set, handles all ops not in the handlers
    const op_handler& get_default_handler() const {
        return m_default_handler;
    }
    /// \brief If set, handles all ops
    void set_universal_handler(const op_handler& handler) {
        m_universal_handler = handler;
    }
    /// \brief If set, handles all ops not in the handlers
    void set_default_handler(const op_handler& handler) {
        m_default_handler = handler;
    }

protected:
    /// \brief Select the handler for \p node.
    ///
    /// Returns the universal handler if set; otherwise the handler registered for the node's
    /// type info; otherwise the default handler. The result may be empty (no handler).
    op_handler get_handler(Node* node) const {
        if (m_universal_handler) {
            return m_universal_handler;
        }
        const auto it = m_handlers.find(node->get_type_info());
        return it == m_handlers.end() ? m_default_handler : it->second;
    }

    class Inst;
    using InstPtr = std::unique_ptr<Inst>;
    using InstStack = std::stack<InstPtr>;

    /// \brief Instructions for the evaluation state machine
    class Inst {
    protected:
        explicit Inst(Node* node) : m_node(node) {}

    public:
        virtual ~Inst() = default;
        /// \brief Perform this instruction for \p node, possibly pushing follow-up instructions.
        virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
        Node* get_node() {
            return m_node;
        }

    protected:
        Node* m_node;
    };

    /// \brief Ensure value has been analyzed
    class ValueInst : public Inst {
    public:
        explicit ValueInst(const Output<Node>& value) : Inst(value.get_node()) {}
        explicit ValueInst(const RawNodeOutput& value) : Inst(value.node) {}

        void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
            // Request to analyze this value if we can
            if (auto handler = evaluator.get_handler(node)) {
                // The execute instruction is pushed first so it is popped last, i.e. only after
                // the value instructions for all of its inputs (pushed above it) were processed.
                inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
                for (const auto& input_value : node->input_values()) {
                    inst_stack.push(InstPtr(new ValueInst(input_value)));
                }
            } else {
                // We don't know how to handle this op, so mark the outputs as unknown
                for (const auto& output : node->outputs()) {
                    evaluator.get_value_map()[output] = V();
                }
            }
        }
    };

    /// \brief All arguments have been handled; execute the node handler
    class ExecuteInst : public Inst {
    public:
        ExecuteInst(Node* node, const op_handler& handler) : Inst(node), m_handler(handler) {}

        void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override {
            // Request to execute the handler. Pass what we know about the inputs to the
            // handler and associate the results with the outputs
            value_map& values = evaluator.get_value_map();
            std::vector<V> inputs;
            inputs.reserve(node->get_input_size());
            for (const auto& input_value : node->input_values()) {
                inputs.push_back(values.at(input_value));
            }
            std::vector<V> outputs = m_handler(node, inputs);
            for (size_t i = 0; i < outputs.size(); ++i) {
                values[node->output(i)] = outputs[i];
            }
            // A handler that returns fewer values than the node has outputs leaves the rest
            // unknown. Recording them keeps the "already computed" check in evaluate() and the
            // input lookups of dependent nodes from failing; emplace keeps pre-seeded values.
            const size_t output_count = node->get_output_size();
            for (size_t i = outputs.size(); i < output_count; ++i) {
                values.emplace(node->output(i), V());
            }
        }

    private:
        op_handler m_handler;
    };

public:
    /// \brief Determine information about value
    ///
    /// \param value The node output to evaluate.
    /// \return The value associated with \p value in the value map after evaluation.
    /// \throws std::out_of_range if no value could be associated with \p value.
    V evaluate(const Output<Node>& value) {
        InstStack inst_stack;
        inst_stack.push(InstPtr(new ValueInst(value)));
        while (!inst_stack.empty()) {
            InstPtr inst = std::move(inst_stack.top());
            inst_stack.pop();
            Node* node = inst->get_node();
            // A node counts as handled once its first output has a value.
            if (m_value_map.find(node->output(0)) != m_value_map.end()) {
                // Already computed
                continue;
            }
            inst->handle(*this, inst_stack, node);
        }
        return m_value_map.at(value);
    }

protected:
    op_handler m_universal_handler;
    op_handler_map m_handlers;
    op_handler m_default_handler;
    value_map& m_value_map;
    NGRAPH_SUPPRESS_DEPRECATED_END
};
}  // namespace ngraph