From 7e6ffa24269678f7437711a9cdb7d6f595fc250f Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Sun, 5 Jan 2025 13:23:43 +0800 Subject: [PATCH] [CPU]Unify u8/u4 dequant kernel with template arg --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 2 +- .../kernels/scaled_attn/attn_quant_kernel.hpp | 88 +------------- .../nodes/kernels/scaled_attn/executor_pa.cpp | 110 ++++-------------- 3 files changed, 25 insertions(+), 175 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index e25b204e670218..26282a70fcb512 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -429,7 +429,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float } void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) { - attn_dequant_u8_kernel(src, dst, n, scale, zp); + attn_dequant_kernel(src, dst, n, scale, zp); } } // namespace XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 97a7d53a2efa05..43bab9b69d5efa 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -17,8 +17,8 @@ namespace Extensions { namespace Cpu { namespace XARCH { -template -void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { +template ::type = true> +void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { size_t i = 0; // loadu_si128/epi64 does not support const qualifier uint8_t* src_nc = const_cast(src); @@ -52,8 +52,8 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } } -template -void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { +template ::type = true> +void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { // 2 4bit data form a byte /* 0,1|2,3|4,5|6,7 / \ @@ -134,86 +134,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } } -template -void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale) { - // 2 4bit data form a byte - /* 0,1|2,3|4,5|6,7 - / \ - 0,2,4,6|1,3,5,7 - | - permute - | - 0,1,2,3,4,5,6,7 - */ - size_t i = 0; - uint8_t* src_nc = const_cast(src); -#if defined(HAVE_AVX512F) - for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { - auto v_scale = _mm512_set1_ps(scale); - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); - // cvt to f32 - auto v_i32 = _mm512_cvtepi8_epi32(data); - - auto v_256_low_half = _mm512_srai_epi32(v_i32, 4); - auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); - auto v_256_high_half = _mm512_slli_epi32(v_i32, 28); - v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); - // q * scale - v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); - v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale); - - __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); - __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); - __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); - __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); - mm512_uni_storeu_ps(dst + i, first_half); - mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); - } - -#elif defined(HAVE_AVX2) - for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { - auto v256_scale = _mm256_set1_ps(scale); - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2)); - - auto v_i32 = _mm256_cvtepi8_epi32(data); - auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); - auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); - - auto v_256_high_half = _mm256_slli_epi32(v_i32, 28); - v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28); - auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); - - // q * scale - v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); - v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); - - // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 - // _mm256_permute2f128_ps - // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 - // _mm256_permutevar8x32_ps - // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 - __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); - auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); - first_half = _mm256_permutevar8x32_ps(first_half, idx1); - __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); - second_half = _mm256_permutevar8x32_ps(second_half, idx1); - mm256_uni_storeu_ps(dst + i, first_half); - mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half); - } -#endif - auto extract_half_byte = [&](uint8_t val, bool high_half) -> int8_t { - uint8_t shift = high_half ? 0 : 4; - return static_cast((val >> shift) & 0x000F); - }; - for (; i < n; ++i) { - float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); - tmp = tmp > 8 ? (tmp - 16) : tmp; - tmp = tmp * scale; - dst[i] = tmp; - } -} - } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index bd659cb1b164f7..955e7687ef97b3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -930,7 +930,11 @@ void transpose_16NxK(TDST* dst, size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]); + attn_dequant_kernel(s + src_offset + sizeof(float) * 2, + t + dst_offset, + group_size, + f[0], + f[1]); src_offset += group_size + sizeof(float) * 2; dst_offset += group_size; } @@ -958,71 +962,25 @@ static inline void dequant(float* dst, ov::float16* src, const size_t N, const s template ::type = true> -void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; - const size_t params_offset = sizeof(float) * 2; - const size_t src_stride = K / group_size * (group_size + params_offset); - - for (size_t n = 0; n < N; n++) { - size_t group_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + group_offset); - attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]); - group_offset += group_size + params_offset; - dst_offset += group_size; - } - s += src_stride; - dst += K; - } -} - -template ::type = true> + typename std::enable_if::type = true> void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float) * 2; - const size_t sub_byte_mulitplier = 2; - - for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]); - src_offset += group_size / sub_byte_mulitplier + params_offset; - dst_offset += group_size; - } - s += src_offset; - dst += K; - } -} - -template ::type = true> -void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; - const size_t params_offset = sizeof(float); - const size_t sub_byte_mulitplier = 2; + const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0]); + attn_dequant_kernel(s + src_offset + params_offset, + dst + dst_offset, + group_size, + f[0], + f[1]); src_offset += group_size / sub_byte_mulitplier + params_offset; dst_offset += group_size; } @@ -1132,40 +1090,8 @@ static void pack_32NxK(TDST* dst, template ::value != ov::element::f32 && SRC_PREC == ov::element::u8, - bool>::type = true> -static void pack_32NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = reinterpret_cast::value_type*>(src); - auto t = tmp; - // if group_size not set, the whole row is used as a group - for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]); - src_offset += group_size + sizeof(float) * 2; - dst_offset += group_size; - } - s += src_offset; - t += src_stride; - } - pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, 0); -} - -template ::value != ov::element::f32 && (SRC_PREC == ov::element::u4), + typename std::enable_if::value != ov::element::f32 && + (SRC_PREC == ov::element::u4 || SRC_PREC == ov::element::u8), bool>::type = true> static void pack_32NxK(TDST* dst, void* src, @@ -1181,13 +1107,17 @@ static void pack_32NxK(TDST* dst, auto s = reinterpret_cast(src); auto t = tmp; // if group_size not set, the whole row is used as a group - const size_t sub_byte_mulitplier = 2; + const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u4_kernel(s + (src_offset + sizeof(float) * 2), t + dst_offset, group_size, f[0], f[1]); + attn_dequant_kernel(s + (src_offset + sizeof(float) * 2), + t + dst_offset, + group_size, + f[0], + f[1]); src_offset += group_size / sub_byte_mulitplier + sizeof(float) * 2; dst_offset += group_size; }