@@ -220,12 +220,14 @@ paddle::Tensor RebuildPaddingV2Func(const paddle::Tensor& tmp_out, // [token_num
220
220
const paddle::optional<paddle::Tensor>& output_padding_offset,
221
221
int max_input_length);
222
222
223
- std::vector<paddle::Tensor> GroupQuant (const paddle::Tensor& x,
223
+ std::vector<paddle::Tensor> PerTokenGroupQuant (const paddle::Tensor& x,
224
224
const int group_size,
225
225
const bool transpose_scale,
226
226
const float quant_max_bound,
227
227
const float quant_min_bound);
228
228
229
+ std::vector<paddle::Tensor> PerTensorQuantFp8 (const paddle::Tensor& x, const paddle::optional<paddle::Tensor>& scale);
230
+
229
231
std::vector<paddle::Tensor> GetPaddingOffsetV2 (const paddle::Tensor& input_ids,
230
232
const paddle::Tensor& cum_offsets,
231
233
const paddle::Tensor& token_num,
@@ -295,7 +297,8 @@ PYBIND11_MODULE(paddlenlp_ops, m) {
295
297
m.def (" f_set_preids_token_penalty_multi_scores" , &SetPreidsTokenPenaltyMultiScores, " SetPreidsTokenPenaltyMultiScores" );
296
298
m.def (" f_update_inputs_v2" , &UpdateInputesV2, " UpdateInputesV2" );
297
299
m.def (" f_rebuild_padding_v2" , &RebuildPaddingV2Func, " RebuildPaddingV2Func" );
298
- m.def (" f_group_quant" , &GroupQuant, " GroupQuant" );
300
+ m.def (" f_per_token_group_quant" , &PerTokenGroupQuant, " PerTokenGroupQuant" );
301
+ m.def (" f_per_tensor_quant_fp8" , &PerTensorQuantFp8, " PerTensorQuantFp8" );
299
302
m.def (" f_get_padding_offset_v2" , &GetPaddingOffsetV2, " GetPaddingOffsetV2" );
300
303
m.def (" f_save_output" , &SaveOutMmsg, " SaveOutMmsg" );
301
304
m.def (" f_get_output" , &GetOutput, " GetOutput" );
@@ -324,7 +327,8 @@ PYBIND11_MODULE(paddlenlp_ops_80, m) {
324
327
m.def (" f_set_preids_token_penalty_multi_scores" , &SetPreidsTokenPenaltyMultiScores, " SetPreidsTokenPenaltyMultiScores" );
325
328
m.def (" f_update_inputs_v2" , &UpdateInputesV2, " UpdateInputesV2" );
326
329
m.def (" f_rebuild_padding_v2" , &RebuildPaddingV2Func, " RebuildPaddingV2Func" );
327
- m.def (" f_group_quant" , &GroupQuant, " GroupQuant" );
330
+ m.def (" f_per_token_group_quant" , &PerTokenGroupQuant, " PerTokenGroupQuant" );
331
+ m.def (" f_per_tensor_quant_fp8" , &PerTensorQuantFp8, " PerTensorQuantFp8" );
328
332
m.def (" f_get_padding_offset_v2" , &GetPaddingOffsetV2, " GetPaddingOffsetV2" );
329
333
m.def (" f_save_output" , &SaveOutMmsg, " SaveOutMmsg" );
330
334
m.def (" f_get_output" , &GetOutput, " GetOutput" );
@@ -352,7 +356,8 @@ PYBIND11_MODULE(paddlenlp_ops_90, m) {
352
356
m.def (" f_set_preids_token_penalty_multi_scores" , &SetPreidsTokenPenaltyMultiScores, " SetPreidsTokenPenaltyMultiScores" );
353
357
m.def (" f_update_inputs_v2" , &UpdateInputesV2, " UpdateInputesV2" );
354
358
m.def (" f_rebuild_padding_v2" , &RebuildPaddingV2Func, " RebuildPaddingV2Func" );
355
- m.def (" f_group_quant" , &GroupQuant, " GroupQuant" );
359
+ m.def (" f_per_token_group_quant" , &PerTokenGroupQuant, " PerTokenGroupQuant" );
360
+ m.def (" f_per_tensor_quant_fp8" , &PerTensorQuantFp8, " PerTensorQuantFp8" );
356
361
m.def (" f_get_padding_offset_v2" , &GetPaddingOffsetV2, " GetPaddingOffsetV2" );
357
362
m.def (" f_save_output" , &SaveOutMmsg, " SaveOutMmsg" );
358
363
m.def (" f_get_output" , &GetOutput, " GetOutput" );
0 commit comments