Skip to content

Commit

Permalink
[CPU] Unify u8/u4 dequant kernels with a template argument
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangYiIntel committed Jan 5, 2025
1 parent 94522a2 commit 7e6ffa2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float
}

// Dequantize a row of n u8 values into f32: dst[i] = (src[i] - zp) * scale.
// Thin wrapper that dispatches to the unified templated dequant kernel.
// NOTE: the scraped diff had left both the old attn_dequant_u8_kernel call and
// its replacement in place, dequantizing the row twice; only the unified
// attn_dequant_kernel call is kept.
void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) {
    attn_dequant_kernel<float, ov::element::u8>(src, dst, n, scale, zp);
}

} // namespace XARCH
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ namespace Extensions {
namespace Cpu {
namespace XARCH {

template <typename TDST>
void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) {
template <typename TDST, ov::element::Type_t SRC_PREC, typename std::enable_if<SRC_PREC == ov::element::u8, bool>::type = true>
void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) {
size_t i = 0;
// loadu_si128/epi64 does not support const qualifier
uint8_t* src_nc = const_cast<uint8_t*>(src);
Expand Down Expand Up @@ -52,8 +52,8 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
}
}

template <typename TDST>
void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) {
template <typename TDST, ov::element::Type_t SRC_PREC, typename std::enable_if<SRC_PREC == ov::element::u4, bool>::type = true>
void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) {
// 2 4bit data form a byte
/* 0,1|2,3|4,5|6,7
/ \
Expand Down Expand Up @@ -134,86 +134,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
}
}

// Dequantize n signed 4-bit (s4) values: dst[i] = q_i * scale, where q_i is
// the sign-extended 4-bit value in [-8, 7]. Two s4 values are packed per
// source byte; even output indices come from the high nibble, odd indices
// from the low nibble.
template <typename TDST>
void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale) {
    // 2 4bit data form a byte. The SIMD paths deinterleave and re-order:
    /* 0,1|2,3|4,5|6,7
            /      \
       0,2,4,6|1,3,5,7
              |
           permute
              |
       0,1,2,3,4,5,6,7
    */
    size_t i = 0;
    // loadu_si128/epi64 does not support const qualifier
    uint8_t* src_nc = const_cast<uint8_t*>(src);
#if defined(HAVE_AVX512F)
    for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) {
        auto v_scale = _mm512_set1_ps(scale);
        auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2));
        // cvt to f32
        auto v_i32 = _mm512_cvtepi8_epi32(data);

        // High nibble: arithmetic shift right keeps the sign of the byte.
        auto v_256_low_half = _mm512_srai_epi32(v_i32, 4);
        auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half);
        // Low nibble: shift left then arithmetic shift right to sign-extend
        // the 4-bit value (maps 8..15 to -8..-1).
        auto v_256_high_half = _mm512_slli_epi32(v_i32, 28);
        v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28);
        auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half);
        // q * scale
        v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale);
        v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale);

        __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
        __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
        __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half);
        __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half);
        mm512_uni_storeu_ps(dst + i, first_half);
        mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half);
    }

#elif defined(HAVE_AVX2)
    for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) {
        auto v256_scale = _mm256_set1_ps(scale);
        auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2));

        auto v_i32 = _mm256_cvtepi8_epi32(data);
        auto v_256_low_half = _mm256_srai_epi32(v_i32, 4);
        auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half);

        auto v_256_high_half = _mm256_slli_epi32(v_i32, 28);
        v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28);
        auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half);

        // q * scale
        v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale);
        v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale);

        // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15
        //          _mm256_permute2f128_ps
        // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15
        //          _mm256_permutevar8x32_ps
        // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15
        __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20);
        auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
        first_half = _mm256_permutevar8x32_ps(first_half, idx1);
        __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31);
        second_half = _mm256_permutevar8x32_ps(second_half, idx1);
        mm256_uni_storeu_ps(dst + i, first_half);
        mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half);
    }
#endif
    // Scalar tail: pick one nibble per element, then sign-extend manually.
    auto extract_half_byte = [&](uint8_t val, bool high_half) -> int8_t {
        uint8_t shift = high_half ? 0 : 4;  // even index -> high nibble, odd index -> low nibble
        // BUGFIX: was static_cast<float> narrowed back into the int8_t return;
        // cast directly to the declared return type.
        return static_cast<int8_t>((val >> shift) & 0x0F);
    };
    for (; i < n; ++i) {
        float tmp = extract_half_byte(src_nc[i / 2], static_cast<bool>(i % 2));
        // Sign-extend the 4-bit value: nibbles 8..15 encode -8..-1.
        // BUGFIX: was `tmp > 8`, which left the nibble 8 as +8 instead of -8
        // and disagreed with the SIMD slli/srai sign-extension above.
        tmp = tmp >= 8 ? (tmp - 16) : tmp;
        tmp = tmp * scale;
        dst[i] = tmp;
    }
}

} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
Expand Down
110 changes: 20 additions & 90 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,11 @@ void transpose_16NxK(TDST* dst,
size_t dst_offset = 0;
while (dst_offset < K) {
auto f = reinterpret_cast<float*>(s + src_offset);
attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]);
attn_dequant_kernel<TDST, SRC_PREC>(s + src_offset + sizeof(float) * 2,
t + dst_offset,
group_size,
f[0],
f[1]);
src_offset += group_size + sizeof(float) * 2;
dst_offset += group_size;
}
Expand Down Expand Up @@ -958,71 +962,25 @@ static inline void dequant(float* dst, ov::float16* src, const size_t N, const s

// Dequantize N rows of K u8-quantized features into dst.
// Per-token-per-head layout of one row:
// |scale(f32)|zeropoint(f32)|q(u8) x group_size| ... repeated K/group_size times,
// so each group occupies group_size + 2*sizeof(float) bytes and the quantized
// payload of a group starts 8 bytes after its header.
template <typename TDST,
          ov::element::Type_t SRC_PREC,
          typename std::enable_if<SRC_PREC == ov::element::u8, bool>::type = true>
void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) {
    const size_t params_offset = sizeof(float) * 2;
    const size_t src_stride = K / group_size * (group_size + params_offset);

    uint8_t* row = src;
    for (size_t n = 0; n < N; n++, row += src_stride, dst += K) {
        size_t in_off = 0;
        for (size_t out_off = 0; out_off < K; out_off += group_size) {
            // Group header: f32 scale followed by f32 zero-point.
            auto params = reinterpret_cast<float*>(row + in_off);
            attn_dequant_u8_kernel(row + in_off + params_offset, dst + out_off, group_size, params[0], params[1]);
            in_off += group_size + params_offset;
        }
    }
}

// Dequantize N rows of K u4-quantized features into dst.
// Per-token-per-head layout of one row:
// |scale(f32)|zeropoint(f32)|quantized feature(u4,idx_1)|...|quantized feature(u4,idx_S)|
// Two u4 features are packed per byte, so a group occupies
// group_size / 2 + 2 * sizeof(float) bytes.
// NOTE: the scraped diff had kept both the old (u4) and new (u4 || u8)
// enable_if constraint lines, double-closing the template header; the single
// constraint matching this u4 body is restored. Also fixes the
// "mulitplier" typo in the local variable name.
template <typename TDST,
          ov::element::Type_t SRC_PREC,
          typename std::enable_if<SRC_PREC == ov::element::u4, bool>::type = true>
void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) {
    auto s = src;
    const size_t params_offset = sizeof(float) * 2;
    const size_t sub_byte_multiplier = 2;  // two u4 values per byte

    for (size_t n = 0; n < N; n++) {
        size_t src_offset = 0;
        size_t dst_offset = 0;
        while (dst_offset < K) {
            auto f = reinterpret_cast<float*>(s + src_offset);  // f[0]=scale, f[1]=zero-point
            attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]);
            src_offset += group_size / sub_byte_multiplier + params_offset;
            dst_offset += group_size;
        }
        // After the loop src_offset equals the full row stride in bytes.
        s += src_offset;
        dst += K;
    }
}

template <typename TDST,
ov::element::Type_t SRC_PREC,
typename std::enable_if<SRC_PREC == ov::element::i4, bool>::type = true>
void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) {
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
// feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
auto s = src;
const size_t params_offset = sizeof(float);
const size_t sub_byte_mulitplier = 2;
const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC);

for (size_t n = 0; n < N; n++) {
size_t src_offset = 0;
size_t dst_offset = 0;
while (dst_offset < K) {
auto f = reinterpret_cast<float*>(s + src_offset);
attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0]);
attn_dequant_kernel<TDST, SRC_PREC>(s + src_offset + params_offset,
dst + dst_offset,
group_size,
f[0],
f[1]);
src_offset += group_size / sub_byte_mulitplier + params_offset;
dst_offset += group_size;
}
Expand Down Expand Up @@ -1132,40 +1090,8 @@ static void pack_32NxK(TDST* dst,

// Dequantize u8-quantized rows into the scratch buffer `tmp`, then delegate
// the actual 32NxK packing to the non-quantized pack_32NxK overload.
// Per-token-per-head layout of one row:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|...|quantized
// feature(u8,idx_S)| — the quantized payload starts 8 bytes
// (= 2 * sizeof(float)) after each group header.
template <typename TDST,
          ov::element::Type_t SRC_PREC,
          typename std::enable_if<precision_of<TDST>::value != ov::element::f32 && SRC_PREC == ov::element::u8,
                                  bool>::type = true>
static void pack_32NxK(TDST* dst,
                       void* src,
                       TDST* tmp,
                       const size_t N,
                       const size_t K,
                       const size_t dst_stride,
                       const size_t src_stride,
                       const size_t group_size) {
    auto quantized = reinterpret_cast<typename ov::element_type_traits<SRC_PREC>::value_type*>(src);
    auto scratch = tmp;
    // if group_size not set, the whole row is used as a group
    for (size_t n = 0; n < N; n++) {
        size_t q_off = 0;
        for (size_t out_off = 0; out_off < K; out_off += group_size) {
            // Group header: f32 scale followed by f32 zero-point.
            auto params = reinterpret_cast<float*>(quantized + q_off);
            attn_dequant_u8_kernel(quantized + q_off + sizeof(float) * 2,
                                   scratch + out_off,
                                   group_size,
                                   params[0],
                                   params[1]);
            q_off += group_size + sizeof(float) * 2;
        }
        quantized += q_off;
        scratch += src_stride;
    }
    // Recurse into the plain (already-dequantized) overload to do the packing.
    pack_32NxK<TDST, precision_of<TDST>::value>(dst, tmp, reinterpret_cast<TDST*>(0), N, K, dst_stride, src_stride, 0);
}

template <typename TDST,
ov::element::Type_t SRC_PREC,
typename std::enable_if<precision_of<TDST>::value != ov::element::f32 && (SRC_PREC == ov::element::u4),
typename std::enable_if<precision_of<TDST>::value != ov::element::f32 &&
(SRC_PREC == ov::element::u4 || SRC_PREC == ov::element::u8),
bool>::type = true>
static void pack_32NxK(TDST* dst,
void* src,
Expand All @@ -1181,13 +1107,17 @@ static void pack_32NxK(TDST* dst,
auto s = reinterpret_cast<uint8_t*>(src);
auto t = tmp;
// if group_size not set, the whole row is used as a group
const size_t sub_byte_mulitplier = 2;
const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC);
for (size_t n = 0; n < N; n++) {
size_t src_offset = 0;
size_t dst_offset = 0;
while (dst_offset < K) {
auto f = reinterpret_cast<float*>(s + src_offset);
attn_dequant_u4_kernel(s + (src_offset + sizeof(float) * 2), t + dst_offset, group_size, f[0], f[1]);
attn_dequant_kernel<TDST, SRC_PREC>(s + (src_offset + sizeof(float) * 2),
t + dst_offset,
group_size,
f[0],
f[1]);
src_offset += group_size / sub_byte_mulitplier + sizeof(float) * 2;
dst_offset += group_size;
}
Expand Down

0 comments on commit 7e6ffa2

Please sign in to comment.