forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NvOnnxParser.h
233 lines (206 loc) · 7 KB
/
NvOnnxParser.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
/*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef NV_ONNX_PARSER_H
#define NV_ONNX_PARSER_H
#include "NvInfer.h"
#include <stddef.h>
#include <vector>
//!
//! \file NvOnnxParser.h
//!
//! This is the API for the ONNX Parser
//!
//! Individual components of the ONNX parser API version.
#define NV_ONNX_PARSER_MAJOR 0
#define NV_ONNX_PARSER_MINOR 1
#define NV_ONNX_PARSER_PATCH 0

//! The version components folded into a single integer:
//! major * 10000 + minor * 100 + patch.
//! createParser() passes this value to createNvOnnxParser_INTERNAL().
static const int NV_ONNX_PARSER_VERSION
    = NV_ONNX_PARSER_MAJOR * 10000
    + NV_ONNX_PARSER_MINOR * 100
    + NV_ONNX_PARSER_PATCH;
//! \typedef SubGraph_t
//!
//! \brief Describes the parsing capability of a set of nodes in an ONNX graph.
//!
//! The type is a pair of a std::vector<size_t> and a bool.
//! NOTE(review): presumably the vector holds node indices and the bool says
//! whether that set of nodes is supported -- confirm against the parser
//! implementation.
//!
using SubGraph_t = std::pair<std::vector<size_t>, bool>;

//! \typedef SubGraphCollection_t
//!
//! \brief Holds every SubGraph_t partitioned out of an ONNX graph.
//!
using SubGraphCollection_t = std::vector<SubGraph_t>;
//!
//! \namespace nvonnxparser
//!
//! \brief The TensorRT ONNX parser API namespace
//!
namespace nvonnxparser
{
/** \brief Per-enum constant, provided through explicit specializations.
 *
 * Only the primary template's declaration appears here; the specialization
 * for ErrorCode follows the enum definition below.
 */
template <typename T>
inline int32_t EnumMax();

/** \enum ErrorCode
 *
 * \brief The kind of error reported by the parser.
 *
 * The numeric values are spelled out explicitly and run contiguously from
 * 0 (kSUCCESS) to 8 (kUNSUPPORTED_NODE).
 */
enum class ErrorCode : int
{
    kSUCCESS                  = 0,
    kINTERNAL_ERROR           = 1,
    kMEM_ALLOC_FAILED         = 2,
    kMODEL_DESERIALIZE_FAILED = 3,
    kINVALID_VALUE            = 4,
    kINVALID_GRAPH            = 5,
    kINVALID_NODE             = 6,
    kUNSUPPORTED_GRAPH        = 7,
    kUNSUPPORTED_NODE         = 8
};

/** \brief Number of enumerators in ErrorCode (currently 9).
 */
template <>
inline int32_t EnumMax<ErrorCode>()
{
    // kUNSUPPORTED_NODE is the last enumerator and the values start at zero,
    // so the enumerator count is its value plus one.
    return static_cast<int32_t>(ErrorCode::kUNSUPPORTED_NODE) + 1;
}
/** \class IParserError
 *
 * \brief Interface exposing the details of one parser error.
 *
 * Every accessor is const and pure virtual; the concrete type is supplied by
 * the parser implementation. The declaration order of the virtual functions
 * is part of the binary interface and must not be changed.
 */
class IParserError
{
public:
    /** \brief The error code.
     */
    virtual ErrorCode code() const = 0;

    /** \brief Description of the error.
     */
    virtual char const* desc() const = 0;

    /** \brief Source file in which the error occurred.
     */
    virtual char const* file() const = 0;

    /** \brief Source line at which the error occurred.
     */
    virtual int line() const = 0;

    /** \brief Source function in which the error occurred.
     */
    virtual char const* func() const = 0;

    /** \brief Index of the ONNX model node in which the error occurred.
     */
    virtual int node() const = 0;

protected:
    /** \brief Protected so that code outside the class hierarchy cannot
     * delete an error object through this interface.
     */
    virtual ~IParserError() = default;
};
/** \class IParser
 *
 * \brief Interface for parsing ONNX models into a TensorRT network definition.
 *
 * The declaration order of the virtual functions is part of the binary
 * interface and must not be changed.
 */
class IParser
{
public:
    /** \brief Parse a serialized ONNX model into the TensorRT network.
     *
     * This method offers only very limited diagnostics. If parsing the
     * serialized model fails for any reason (e.g. unsupported IR version,
     * unsupported opset, etc.), it is the user's responsibility to intercept
     * and report the error. For better diagnostics, use parseFromFile().
     *
     * \param serialized_onnx_model Pointer to the serialized ONNX model
     * \param serialized_onnx_model_size Size of the serialized ONNX model
     *        in bytes
     * \param model_path Absolute path to the model file, used for loading
     *        external weights if required (defaults to nullptr)
     * \return true if the model was parsed successfully
     * \see getNbErrors() getError()
     */
    virtual bool parse(const void* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) = 0;

    /** \brief Parse an ONNX model file, which can be a binary protobuf or a
     * text ONNX model. Calls the parse() method internally.
     *
     * \param onnxModelFile File name of the model
     * \param verbosity Verbosity level
     *
     * \return true if the model was parsed successfully
     */
    virtual bool parseFromFile(const char* onnxModelFile, int verbosity) = 0;

    /** \brief Check whether TensorRT supports a particular ONNX model.
     *
     * \param serialized_onnx_model Pointer to the serialized ONNX model
     * \param serialized_onnx_model_size Size of the serialized ONNX model
     *        in bytes
     * \param sub_graph_collection Container to hold supported subgraphs
     * \param model_path Absolute path to the model file, used for loading
     *        external weights if required (defaults to nullptr)
     * \return true if the model is supported
     */
    virtual bool supportsModel(const void* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) = 0;

    /** \brief Parse a serialized ONNX model into the TensorRT network,
     * taking user-provided weights into consideration.
     *
     * \param serialized_onnx_model Pointer to the serialized ONNX model
     * \param serialized_onnx_model_size Size of the serialized ONNX model
     *        in bytes
     * \return true if the model was parsed successfully
     * \see getNbErrors() getError()
     */
    virtual bool parseWithWeightDescriptors(const void* serialized_onnx_model, size_t serialized_onnx_model_size) = 0;

    /** \brief Report whether the specified operator may be supported by the
     * parser.
     *
     * A result of true does not guarantee that the operator is supported in
     * all cases (i.e., this function may return false positives).
     *
     * \param op_name The name of the ONNX operator to check for support
     */
    virtual bool supportsOperator(const char* op_name) const = 0;

    /** \brief Destroy this object.
     *
     * \warning deprecated and planned on being removed in TensorRT 10.0
     */
    TRT_DEPRECATED virtual void destroy() = 0;

    /** \brief Get the number of errors that occurred during prior calls to
     * \p parse.
     *
     * \see getError() clearErrors() IParserError
     */
    virtual int getNbErrors() const = 0;

    /** \brief Get an error that occurred during prior calls to \p parse.
     *
     * \param index Which error to fetch. NOTE(review): presumably in the
     *        range [0, getNbErrors()) -- confirm against the implementation.
     *
     * \see getNbErrors() clearErrors() IParserError
     */
    virtual const IParserError* getError(int index) const = 0;

    /** \brief Clear errors from prior calls to \p parse.
     *
     * \see getNbErrors() getError() IParserError
     */
    virtual void clearErrors() = 0;

    /** \brief Public virtual destructor, so a parser can be deleted through
     * an IParser pointer.
     */
    virtual ~IParser() noexcept = default;
};
} // namespace nvonnxparser
extern "C"
{
    //! Untyped, C-linkage factory behind nvonnxparser::createParser(), which
    //! calls it with the addresses of the network definition and logger plus
    //! NV_ONNX_PARSER_VERSION, and casts the result to IParser*.
    //! Not intended to be called directly; use createParser() instead.
    TENSORRTAPI void* createNvOnnxParser_INTERNAL(void* network, void* logger, int version);

    //! NOTE(review): presumably returns the version of the ONNX parser
    //! library that is linked in -- confirm against the implementation.
    TENSORRTAPI int getNvOnnxParserVersion();
}
namespace nvonnxparser
{
namespace
{
/** \brief Create a new parser object.
 *
 * Forwards to createNvOnnxParser_INTERNAL(), passing the addresses of
 * \p network and \p logger together with the NV_ONNX_PARSER_VERSION this
 * header was compiled with.
 *
 * \param network The network definition that the parser will write to
 * \param logger The logger to use
 * \return a new parser object or NULL if an error occurred
 *
 * Input dimensions that are constant should not be changed after parsing,
 * because the correctness of the translation may rely on those constants.
 * Changing a dynamic input dimension (one that translates to -1 in TensorRT)
 * to a constant is okay if the constant is consistent with the model.
 *
 * \see IParser
 */
inline IParser* createParser(nvinfer1::INetworkDefinition& network, nvinfer1::ILogger& logger)
{
    // The C-linkage factory returns an untyped pointer; convert it back to
    // the interface type.
    void* const rawParser = createNvOnnxParser_INTERNAL(&network, &logger, NV_ONNX_PARSER_VERSION);
    return static_cast<IParser*>(rawParser);
}
} // namespace
} // namespace nvonnxparser
#endif // NV_ONNX_PARSER_H