Skip to content

Commit d24ec24

Browse files
authored
[infer] Fix per_token_group_quant and support per_tensor_quant (#10359)
* support per_tensor_quant and per_token_group_quant
* check
* add per_token_group_quant, per_tensor_quant_fp8 into pybind_ops_list
1 parent 737a679 commit d24ec24

11 files changed

+684
-216
lines changed

csrc/gpu/cpp_extensions.cu

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,14 @@ paddle::Tensor RebuildPaddingV2Func(const paddle::Tensor& tmp_out, // [token_num
220220
const paddle::optional<paddle::Tensor>& output_padding_offset,
221221
int max_input_length);
222222

223-
std::vector<paddle::Tensor> GroupQuant(const paddle::Tensor& x,
223+
std::vector<paddle::Tensor> PerTokenGroupQuant(const paddle::Tensor& x,
224224
const int group_size,
225225
const bool transpose_scale,
226226
const float quant_max_bound,
227227
const float quant_min_bound);
228228

229+
std::vector<paddle::Tensor> PerTensorQuantFp8(const paddle::Tensor& x, const paddle::optional<paddle::Tensor>& scale);
230+
229231
std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
230232
const paddle::Tensor& cum_offsets,
231233
const paddle::Tensor& token_num,
@@ -295,7 +297,8 @@ PYBIND11_MODULE(paddlenlp_ops, m) {
295297
m.def("f_set_preids_token_penalty_multi_scores", &SetPreidsTokenPenaltyMultiScores, "SetPreidsTokenPenaltyMultiScores");
296298
m.def("f_update_inputs_v2", &UpdateInputesV2, "UpdateInputesV2");
297299
m.def("f_rebuild_padding_v2", &RebuildPaddingV2Func, "RebuildPaddingV2Func");
298-
m.def("f_group_quant", &GroupQuant, "GroupQuant");
300+
m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
301+
m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
299302
m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
300303
m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
301304
m.def("f_get_output", &GetOutput, "GetOutput");
@@ -324,7 +327,8 @@ PYBIND11_MODULE(paddlenlp_ops_80, m) {
324327
m.def("f_set_preids_token_penalty_multi_scores", &SetPreidsTokenPenaltyMultiScores, "SetPreidsTokenPenaltyMultiScores");
325328
m.def("f_update_inputs_v2", &UpdateInputesV2, "UpdateInputesV2");
326329
m.def("f_rebuild_padding_v2", &RebuildPaddingV2Func, "RebuildPaddingV2Func");
327-
m.def("f_group_quant", &GroupQuant, "GroupQuant");
330+
m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
331+
m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
328332
m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
329333
m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
330334
m.def("f_get_output", &GetOutput, "GetOutput");
@@ -352,7 +356,8 @@ PYBIND11_MODULE(paddlenlp_ops_90, m) {
352356
m.def("f_set_preids_token_penalty_multi_scores", &SetPreidsTokenPenaltyMultiScores, "SetPreidsTokenPenaltyMultiScores");
353357
m.def("f_update_inputs_v2", &UpdateInputesV2, "UpdateInputesV2");
354358
m.def("f_rebuild_padding_v2", &RebuildPaddingV2Func, "RebuildPaddingV2Func");
355-
m.def("f_group_quant", &GroupQuant, "GroupQuant");
359+
m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
360+
m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
356361
m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
357362
m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
358363
m.def("f_get_output", &GetOutput, "GetOutput");

csrc/gpu/group_quant.cu

Lines changed: 0 additions & 201 deletions
This file was deleted.

csrc/gpu/helper.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,13 @@ class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
162162
typedef paddle::float8_e4m3fn data_t;
163163
};
164164

// PDTraits specialization mapping paddle's INT8 data type to the native
// int8_t C++ type (mirrors the FLOAT8_E4M3FN specialization above).
// Kernels use PDTraits<dt>::DataType / ::data_t to pick storage types
// from a compile-time paddle::DataType.
template <>
class PDTraits<paddle::DataType::INT8> {
public:
  typedef int8_t DataType;
  typedef int8_t data_t;
};
165172
template <typename T, int Size>
166173
struct alignas(sizeof(T) * Size) AlignedVector {
167174
T val[Size];
@@ -245,3 +252,35 @@ inline bool GetMlaUseTensorcore() {
245252
const bool mla_use_tensorcore = flags_mla_use_tensorcore && enable_mla_tensorcore;
246253
return mla_use_tensorcore;
247254
}
// Atomically computes max(*addr, value) for floats using integer atomics
// (CUDA has no native float atomicMax). Relies on the IEEE-754 bit-pattern
// ordering trick: non-negative floats order like signed ints, while negative
// floats order in reverse like unsigned ints.
// NOTE(review): assumes no NaN inputs and that the -0.0f/+0.0f distinction
// is irrelevant — confirm against callers.
// Returns the value previously stored at addr.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float previous;
  if (value >= 0) {
    // Non-negative path: signed-int max matches float max.
    previous = __int_as_float(atomicMax((int*)addr, __float_as_int(value)));
  } else {
    // Negative path: unsigned-int min matches float max for negative values.
    previous = __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
  }
  return previous;
}
// XOR-butterfly max-reduction across the 32 lanes of a warp.
// After the call every lane holds the warp-wide maximum.
// Requires all 32 lanes of the warp to be active (full 0xffffffff mask).
__device__ __forceinline__ float warpReduceMax(float max_value) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, offset));
  }
  return max_value;
}
// Block-wide max-reduction of max_value across all threads of the block.
// The result is valid in every lane of warp 0 (callers should read it from
// warp 0 / thread 0, matching the original contract); other warps keep
// their own warp-level partial.
// Preconditions: 1-D block with blockDim.x <= 1024 (at most 32 warps).
// NOTE: uses `static __shared__` scratch — a second call without an
// intervening __syncthreads() would race with the first.
__device__ __forceinline__ float blockReduceMax(float max_value) {
  static __shared__ float warpLevelMaxs[32];
  const int laneId = threadIdx.x & 0x1f;  // fixed: stray `;;` removed
  const int warpId = threadIdx.x >> 5;
  // Ceil-div: the original used blockDim.x / 32, which truncates and
  // silently drops the trailing partial warp's maximum whenever
  // blockDim.x is not a multiple of 32.
  const int numWarps = (blockDim.x + 31) >> 5;

  max_value = warpReduceMax(max_value);

  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
  __syncthreads();

  // 0.0f is the identity for the inactive lanes of warp 0; this is only
  // correct because callers reduce non-negative (absolute) values — for a
  // general signed max the identity would have to be -FLT_MAX.
  max_value = (threadIdx.x < numWarps) ? warpLevelMaxs[laneId] : 0.0f;
  if (warpId == 0) max_value = warpReduceMax(max_value);

  return max_value;
}

0 commit comments

Comments (0)