forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqkvToContextPlugin.h
464 lines (360 loc) · 16.3 KB
/
qkvToContextPlugin.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Need 10.1 for cublasGemmStridedBatchedEx
#include <cuda.h>
#if CUDA_VERSION >= 10010
#ifndef TRT_QKV_TO_CONTEXT_PLUGIN_H
#define TRT_QKV_TO_CONTEXT_PLUGIN_H
#include "NvInferPlugin.h"
#include "cublas_v2.h"
#include "zeroPadding2d.h"

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace nvinfer1
{
namespace plugin
{
namespace bert
{
// Multi Head Attention runner
class MHARunner
{
public:
MHARunner(const nvinfer1::DataType type, const int32_t numHeads, const int32_t headSize)
: mType(type)
, mS(0)
, mB(0)
, mOmatSize(0)
, mNumMats(0)
, mNumHeads(numHeads)
, mHeadSize(headSize)
, mWordSize(getElementSize(type))
, mLdQKV(0)
, mStrideQKV(0)
, mLdOut(0)
, mStrideOut(0)
, mRsqrtHeadSize(1.F / sqrtf(headSize))
{
}
virtual ~MHARunner() = default;
virtual void setup(const int32_t S, const int32_t B)
{
PLUGIN_ASSERT(S);
PLUGIN_ASSERT(B);
mB = B;
mS = S;
mLdQKV = 3 * B * mNumHeads * mHeadSize;
mStrideQKV = 3 * mHeadSize;
mLdOut = B * mNumHeads * mHeadSize;
mStrideOut = mHeadSize;
mOmatSize = S * S;
mNumMats = B * mNumHeads;
}
virtual void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc,
void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream)
= 0;
virtual void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
= 0;
virtual size_t getSerializationSize() const noexcept;
virtual void serialize(void* buffer) const noexcept;
virtual void deserialize(void const* data, size_t length);
virtual size_t getWorkspaceSize() const = 0;
virtual bool isValid(int32_t s) const = 0;
protected:
nvinfer1::DataType mType;
int32_t mS;
int32_t mB;
int32_t mOmatSize;
int32_t mNumMats;
int32_t mNumHeads;
int32_t mHeadSize;
int32_t mWordSize;
int32_t mLdQKV;
int32_t mStrideQKV;
int32_t mLdOut;
int32_t mStrideOut;
float mRsqrtHeadSize;
};
// Returns a pair of int32_t values selected for the problem size (B, S, numHeads, headSize).
// NOTE(review): presumably the two batched-GEMM algorithm ids that UnfusedMHARunner keeps in
// mAlgoBatchedEx1 / mAlgoBatchedEx2 -- confirm against the implementation.
std::pair<int32_t, int32_t> tuneBatchedGemm(
    int32_t const B, int32_t const S, int32_t const numHeads, int32_t const headSize);

// Softmax kernel launcher on `stream`, reading `input` and writing `output`, with
// rsqrtHeadSize as the scale factor (cf. MHARunner::mRsqrtHeadSize). Defined out of line.
// NOTE(review): the int32_t result is presumably a status code -- confirm.
template <typename T>
int32_t computeScaledSoftmax(cudaStream_t stream, int32_t const ld, int32_t const B, int32_t const N,
    float const rsqrtHeadSize, T const* input, T* output);

// Masked counterpart of computeScaledSoftmax; additionally takes the int32_t mask indices maskIdx.
template <typename T>
int32_t computeMaskedScaledSoftmax(cudaStream_t stream, int32_t const ld, int32_t const B, int32_t const N,
    float const rsqrtHeadSize, int32_t const* maskIdx, T const* input, T* output);
// One of the preferred ways of making TensorRT able to see our custom layer
// requires extending the IPluginV2 and IPluginCreator classes.
// For the requirements on the overridden functions, check the TensorRT API docs.
//
// QKVToContextPluginDynamic is an IPluginV2DynamicExt plugin that delegates the
// attention computation to MHARunner instances (fusedDispatcher /
// unfusedDispatcher; see the member comments below for when each is used).
class QKVToContextPluginDynamic : public nvinfer1::IPluginV2DynamicExt
{
public:
    //! Construct from explicit attributes.
    QKVToContextPluginDynamic(std::string const name, nvinfer1::DataType const type, int32_t const hiddenSize,
        int32_t const numHeads, float const dqProbs, bool hasImask = false);

    //! Construct from a serialized buffer (`length` bytes at `data`).
    QKVToContextPluginDynamic(std::string const name, void const* data, size_t length);

    // It doesn't make sense to make QKVToContextPluginDynamic without arguments, so we
    // delete the default constructor.
    QKVToContextPluginDynamic() = delete;

    // IPluginV2DynamicExt Methods
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;
    bool supportsFormatCombination(
        int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
    void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
        nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
    size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
        nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
    int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // IPluginV2Ext Methods
    nvinfer1::DataType getOutputDataType(
        int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;

    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int32_t getNbOutputs() const noexcept override;
    int32_t initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

protected:
    void createMHARunner();

private:
    const std::string mLayerName;
    std::string mNamespace;

    // used for sequence len 128, 384, precision int8
    // used for sequence len 64, 96, 128, 384, precision fp16
    std::unique_ptr<MHARunner> fusedDispatcher;
    // used for other sequence lengths, precision fp32 and fp16
    std::unique_ptr<MHARunner> unfusedDispatcher;

    // Scalar state is value-initialized ({} -> 0 / false / DataType{}) so that any
    // field a constructor leaves unset is zero instead of indeterminate. A constructor
    // mem-initializer for a field still takes precedence over these defaults.
    int32_t mS{};
    int32_t mB{};
    int32_t mSM{};
    int32_t mHeadSize{};
    int32_t mHiddenSize{};
    int32_t mNumHeads{};
    bool mHasImask{};
    nvinfer1::DataType mType{};
    float mDqProbs{};

    using IPluginV2::getOutputDimensions;
    using IPluginV2::getWorkspaceSize;
    using IPluginV2::enqueue;
    using IPluginV2Ext::configurePlugin;
};
// Plugin creator (nvinfer1::IPluginCreator) for QKVToContextPluginDynamic: builds the
// plugin either from a PluginFieldCollection (createPlugin) or from serialized bytes
// (deserializePlugin).
class QKVToContextPluginDynamicCreator : public nvinfer1::IPluginCreator
{
public:
    QKVToContextPluginDynamicCreator();

    // Identification.
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;

    // Plugin construction.
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
    nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData,
        size_t serialLength) noexcept override;

    // Namespace handling.
    void setPluginNamespace(char const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    // Class-wide field descriptions (shared by all creator instances).
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    static nvinfer1::PluginFieldCollection mFC;

    std::string mNamespace;
};
// QKVToContextVarSeqlenPlugin is an IPluginV2DynamicExt attention plugin with an
// optional variable-sequence-length mode (varSeqlen / mUseVarSeqlen). The attention
// work is delegated to an MHARunner (`dispatcher`); `patcher` holds a QkvPaddingRunner.
class QKVToContextVarSeqlenPlugin : public nvinfer1::IPluginV2DynamicExt
{
public:
    //! Construct from explicit attributes.
    QKVToContextVarSeqlenPlugin(std::string const name, nvinfer1::DataType const type, int32_t const hiddenSize,
        int32_t const numHeads, float const dqProbs, bool hasImask = false, bool varSeqlen = false,
        bool const useInt8ScaleMax = true);

    //! Construct from a serialized buffer (`length` bytes at `data`).
    QKVToContextVarSeqlenPlugin(std::string const name, void const* data, size_t length);

    // It doesn't make sense to make QKVToContextVarSeqlenPlugin without arguments, so we
    // delete the default constructor.
    QKVToContextVarSeqlenPlugin() = delete;

    // IPluginV2DynamicExt Methods
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;
    bool supportsFormatCombination(
        int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
    void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
        nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
    size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
        nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override;
    int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // IPluginV2Ext Methods
    nvinfer1::DataType getOutputDataType(
        int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;

    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int32_t getNbOutputs() const noexcept override;
    int32_t initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

protected:
    void createMHARunner();

private:
    const std::string mLayerName;
    std::string mNamespace;

    std::unique_ptr<MHARunner> dispatcher;
    std::unique_ptr<QkvPaddingRunner> patcher;

    // Scalar state is value-initialized ({} -> 0 / false / DataType{}) so that any
    // field a constructor leaves unset is zero instead of indeterminate, consistent
    // with the existing mUseInt8ScaleMax default. A constructor mem-initializer for a
    // field still takes precedence over these defaults.
    int32_t mS{};
    int32_t mB{};
    int32_t mSM{};
    int32_t mHeadSize{};
    int32_t mHiddenSize{};
    int32_t mNumHeads{};
    bool mHasImask{};
    nvinfer1::DataType mType{};
    float mDqProbs{};
    int32_t mHdim{};
    bool mUseVarSeqlen{};
    bool mUseInt8ScaleMax{true};

    using IPluginV2::getOutputDimensions;
    using IPluginV2::getWorkspaceSize;
    using IPluginV2::enqueue;
    using IPluginV2Ext::configurePlugin;
};
// Plugin creator (nvinfer1::IPluginCreator) for QKVToContextVarSeqlenPlugin: builds the
// plugin either from a PluginFieldCollection (createPlugin) or from serialized bytes
// (deserializePlugin).
class QKVToContextVarSeqlenPluginCreator : public nvinfer1::IPluginCreator
{
public:
    QKVToContextVarSeqlenPluginCreator();

    // Identification.
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;

    // Plugin construction.
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
    nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData,
        size_t serialLength) noexcept override;

    // Namespace handling.
    void setPluginNamespace(char const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    // Class-wide field descriptions (shared by all creator instances).
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    static nvinfer1::PluginFieldCollection mFC;

    std::string mNamespace;
};
class UnfusedMHARunner : public MHARunner
{
public:
UnfusedMHARunner(
const nvinfer1::DataType type, const int32_t numHeads, const int32_t headSize, const int32_t smVersion);
virtual ~UnfusedMHARunner();
virtual void setup(const int32_t S, const int32_t B) override;
void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc,
void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream) override;
void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
size_t getWorkspaceSize() const override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void deserialize(void const* data, size_t length) override;
bool isValid(int32_t s) const override;
private:
bool mIsBestAlgoFound;
int32_t mAlgoBatchedEx1;
int32_t mAlgoBatchedEx2;
cublasHandle_t mCublas;
int32_t mSm;
};
class FusedMHARunnerFP16 : public MHARunner
{
public:
FusedMHARunnerFP16(const int32_t numHeads, const int32_t headSize, const int32_t sm);
~FusedMHARunnerFP16() = default; // for pimpl
virtual void setup(const int32_t S, const int32_t B) override;
void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc,
void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream) override;
void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
size_t getWorkspaceSize() const override;
void deserialize(void const* data, size_t length) override;
bool isValid(int32_t s) const override;
private:
int32_t mSm;
class mhaImpl;
std::unique_ptr<mhaImpl> pimpl;
};
class FusedMHARunnerInt8 : public MHARunner
{
public:
FusedMHARunnerInt8(const int32_t numHeads, const int32_t headSize, const int32_t sm, float const dqProbs);
~FusedMHARunnerInt8() = default; // for pimpl
virtual void setup(const int32_t S, const int32_t B) override;
void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc,
void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream) override;
void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
size_t getWorkspaceSize() const override;
void deserialize(void const* data, size_t length) override;
bool isValid(int32_t s) const override;
private:
float mDqProbs;
int32_t mSm;
class mhaImpl;
std::unique_ptr<mhaImpl> pimpl;
};
class FusedMHARunnerFP16v2 : public MHARunner
{
public:
FusedMHARunnerFP16v2(const int32_t numHeads, const int32_t headSize, const int32_t sm);
~FusedMHARunnerFP16v2() = default; // for pimpl
virtual void setup(const int32_t S, const int32_t B) override;
void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc,
void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream) override;
void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
size_t getWorkspaceSize() const override;
void deserialize(void const* data, size_t length) override;
bool isValid(int32_t s) const override;
private:
int32_t mSm;
class mhaImpl;
std::unique_ptr<mhaImpl> pimpl;
};
class FusedMHARunnerInt8v2 : public MHARunner
{
public:
FusedMHARunnerInt8v2(int32_t const numHeads, int32_t const headSize, int32_t const sm, float const dqProbs,
bool const useInt8ScaleMax);
~FusedMHARunnerInt8v2() = default; // for pimpl
virtual void setup(const int32_t S, const int32_t B) override;
void run(nvinfer1::PluginTensorDesc const& inputDesc, nvinfer1::PluginTensorDesc const& outputDesc,
void const* qkvPtr, void const* maskPtr, void* output, void* workspace, cudaStream_t stream) override;
void run(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
size_t getWorkspaceSize() const override;
void deserialize(void const* data, size_t length) override;
bool isValid(int32_t s) const override;
private:
float mDqProbs;
int32_t mSm;
class mhaImpl;
std::unique_ptr<mhaImpl> pimpl;
bool mUseInt8ScaleMax{true};
};
} // namespace bert
} // namespace plugin
} // namespace nvinfer1
#endif // TRT_QKV_TO_CONTEXT_PLUGIN_H
#endif // CUDA_VERSION >= 10010