From c362399b8bbf7c8346dd8b1474b3e5980c4269af Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Fri, 3 Jan 2025 13:40:59 +0800
Subject: [PATCH] [CPU] Apply review comments

Signed-off-by: Zhang Yi
---
 src/plugins/intel_cpu/src/config.cpp              |  57 +++---
 .../nodes/kernels/scaled_attn/attn_quant.cpp      | 163 ++++-------------
 .../intel_cpu/src/nodes/paged_attn.cpp            |   9 +-
 .../ov_executable_network/properties.cpp          |  30 ++++
 4 files changed, 98 insertions(+), 161 deletions(-)

diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp
index 5b0d677d11b54f..b16e270b984fca 100644
--- a/src/plugins/intel_cpu/src/config.cpp
+++ b/src/plugins/intel_cpu/src/config.cpp
@@ -373,43 +373,42 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
                                ov::hint::kv_cache_precision.name(),
                                ". Supported values: u8, bf16, f16, f32");
             }
-        } else if (key == ov::hint::key_cache_precision.name() || key == ov::hint::value_cache_precision.name()) {
+        } else if (key == ov::hint::key_cache_precision.name()) {
             try {
                 kvCachePrecisionSetExplicitly = true;
                 auto const prec = val.as<ov::element::Type>();
-                if (key == ov::hint::key_cache_precision.name()) {
-                    if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
-                        keyCachePrecision = prec;
-                    } else {
-                        OPENVINO_THROW("keyCachePrecision doesn't support value ", prec);
-                    }
+                if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
+                    keyCachePrecision = prec;
                 } else {
-                    if (one_of(prec,
-                               ov::element::f32,
-                               ov::element::f16,
-                               ov::element::bf16,
-                               ov::element::u8,
-                               ov::element::u4,
-                               ov::element::i4)) {
-                        valueCachePrecision = prec;
-                    } else {
-                        OPENVINO_THROW("valueCachePrecision doesn't support value ", prec);
-                    }
+                    OPENVINO_THROW("keyCachePrecision doesn't support value ", prec);
                 }
             } catch (ov::Exception&) {
-                if (key == ov::hint::key_cache_precision.name()) {
-                    OPENVINO_THROW("Wrong value ",
-                                   val.as<std::string>(),
-                                   " for property key ",
-                                   ov::hint::key_cache_precision.name(),
-                                   ". Supported values: u8, bf16, f16, f32");
+                OPENVINO_THROW("Wrong value ",
+                               val.as<std::string>(),
+                               " for property key ",
+                               ov::hint::key_cache_precision.name(),
+                               ". Supported values: u8, bf16, f16, f32");
+            }
+        } else if (key == ov::hint::value_cache_precision.name()) {
+            try {
+                kvCachePrecisionSetExplicitly = true;
+                auto const prec = val.as<ov::element::Type>();
+                if (one_of(prec,
+                           ov::element::f32,
+                           ov::element::f16,
+                           ov::element::bf16,
+                           ov::element::u8,
+                           ov::element::u4)) {
+                    valueCachePrecision = prec;
                 } else {
-                    OPENVINO_THROW("Wrong value ",
-                                   val.as<std::string>(),
-                                   " for property key ",
-                                   ov::hint::value_cache_precision.name(),
-                                   ". Supported values: u4, s4, u8, bf16, f16, f32");
+                    OPENVINO_THROW("valueCachePrecision doesn't support value ", prec);
                 }
+            } catch (ov::Exception&) {
+                OPENVINO_THROW("Wrong value ",
+                               val.as<std::string>(),
+                               " for property key ",
+                               ov::hint::value_cache_precision.name(),
+                               ". Supported values: u4, u8, bf16, f16, f32");
             }
         } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) {
             try {
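The split above makes the two hints independently parsed and validated: the key cache accepts u8/bf16/f16/f32, while the value cache additionally accepts u4. A minimal usage sketch, not part of the patch (the "model.xml" path is a placeholder):

    #include <openvino/openvino.hpp>

    int main() {
        ov::Core core;
        // Key cache accepts u8/bf16/f16/f32; the value cache additionally accepts u4.
        core.set_property("CPU", ov::hint::key_cache_precision(ov::element::u8));
        core.set_property("CPU", ov::hint::value_cache_precision(ov::element::u4));
        // "model.xml" stands in for an LLM IR that carries a KV cache.
        ov::CompiledModel compiled = core.compile_model("model.xml", "CPU");
        // The compiled model reports back the configured precision.
        return compiled.get_property(ov::hint::value_cache_precision) == ov::element::u4 ? 0 : 1;
    }

An out-of-range value such as i8 now produces a per-property error message instead of the shared kv_cache_precision one.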
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
index fb6dc8439ac9bf..40ed27bf73ea97 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
@@ -218,7 +218,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) {
         _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined);
     }
 #endif
-#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
+#if defined(HAVE_AVX2)
     auto v256_zero = _mm256_set1_epi32(0);
     auto v256_upper = _mm256_set1_epi32(15);
     auto v256_scale = _mm256_set1_ps(1 / scale);
@@ -273,7 +273,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) {
 }
 
 template <typename T>
-static void quant_s4(const T* src, void* dst, size_t n, float& scale) {
+static void quant_i4(const T* src, void* dst, size_t n, float& scale) {
     auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
         uint8_t shift = high_half ? 0 : 4;
         if (high_half)
@@ -318,7 +318,7 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) {
         _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined);
     }
 #endif
-#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
+#if defined(HAVE_AVX2)
     auto v256_lower = _mm256_set1_epi32(-8);
     auto v256_upper = _mm256_set1_epi32(7);
     auto v256_scale = _mm256_set1_ps(1 / scale);
@@ -372,6 +372,27 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) {
     }
 }
 
+template <typename T,
+          ov::element::Type_t DST_PREC,
+          typename std::enable_if<DST_PREC == ov::element::u8, bool>::type = true>
+static void quantize(const T* src, uint8_t* dst, size_t n, float* scale_zp) {
+    quant_u8(src, dst, n, *scale_zp, *(scale_zp + 1));
+}
+
+template <typename T,
+          ov::element::Type_t DST_PREC,
+          typename std::enable_if<DST_PREC == ov::element::u4, bool>::type = true>
+static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
+    quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1));
+}
+
+template <typename T,
+          ov::element::Type_t DST_PREC,
+          typename std::enable_if<DST_PREC == ov::element::i4, bool>::type = true>
+static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
+    quant_i4(src, dst, n, *scale_zp);
+}
+
 template <typename T, typename T2>
 static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                           const ov::intel_cpu::PlainTensor& v_src,
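The quantize() overloads introduced above select among quant_u8/quant_u4/quant_i4 with enable_if on the destination precision, so the dispatch is fully resolved at compile time. A stripped-down sketch of the pattern, using standalone names and simplified signatures rather than the patch's real types:

    #include <cstddef>
    #include <cstdint>
    #include <type_traits>

    enum class Prec { u8, u4, i4 };

    // Exactly one overload survives substitution for a given DST_PREC.
    template <Prec DST_PREC, typename std::enable_if<DST_PREC == Prec::u8, bool>::type = true>
    void quantize(const float* /*src*/, uint8_t* /*dst*/, std::size_t /*n*/, float* /*scale_zp*/) {
        // quant_u8(src, dst, n, *scale_zp, *(scale_zp + 1)) would run here.
    }

    template <Prec DST_PREC, typename std::enable_if<DST_PREC == Prec::u4, bool>::type = true>
    void quantize(const float* /*src*/, void* /*dst*/, std::size_t /*n*/, float* /*scale_zp*/) {
        // quant_u4(...) would run here; two 4-bit results share each output byte.
    }

    int main() {
        float src[16] = {};
        uint8_t dst[16] = {};
        float scale_zp[2] = {};
        quantize<Prec::u8>(src, dst, 16, scale_zp);                     // binds to the u8 overload
        quantize<Prec::u4>(src, static_cast<void*>(dst), 8, scale_zp);  // binds to the u4 overload
    }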
@@ -389,10 +410,7 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
     });
 }
 
-template <typename T,
-          ov::element::Type_t DST_PREC,
-          typename std::enable_if<DST_PREC == ov::element::u8, bool>::type = true>
+template <typename T, ov::element::Type_t KEY_DST_PREC, ov::element::Type_t VALUE_DST_PREC>
 static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                                 const ov::intel_cpu::PlainTensor& v_src,
                                 const ov::intel_cpu::PlainTensor& k_dst,
@@ -402,6 +420,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                                 const size_t value_group_size) {
     size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3];
     size_t block_size = k_dst.m_dims[2];
+    size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth();
     parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
         auto slot = slot_mapping.ptr<int32_t>(b)[m];
         if (slot < 0)
@@ -418,76 +437,15 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                                                                                   h,
                                                                                   block_offset,
                                                                                   dst_offset));
-            quant_u8(k_src.ptr<T>(b, h, m, src_offset),
-                     k_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                       h,
-                                                                                       block_offset,
-                                                                                       dst_offset) +
-                         sizeof(float) + sizeof(float),
-                     key_group_size,
-                     p_k[0],
-                     p_k[1]);
-        }
-
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
-             src_offset += value_group_size, dst_offset += value_group_size + sizeof(float) + sizeof(float)) {
-            auto p_v = reinterpret_cast<float*>(
-                v_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                  h,
-                                                                                  block_offset,
-                                                                                  dst_offset));
-            quant_u8(v_src.ptr<T>(b, h, m, src_offset),
-                     v_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                       h,
-                                                                                       block_offset,
-                                                                                       dst_offset) +
-                         sizeof(float) + sizeof(float),
-                     value_group_size,
-                     p_v[0],
-                     p_v[1]);
-        }
-    });
-}
-
-template <typename T,
-          ov::element::Type_t DST_PREC,
-          typename std::enable_if<DST_PREC == ov::element::u4, bool>::type = true>
-static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
-                                const ov::intel_cpu::PlainTensor& v_src,
-                                const ov::intel_cpu::PlainTensor& k_dst,
-                                const ov::intel_cpu::PlainTensor& v_dst,
-                                const ov::intel_cpu::PlainTensor& slot_mapping,
-                                const size_t key_group_size,
-                                const size_t value_group_size) {
-    size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3];
-    size_t block_size = k_dst.m_dims[2];
-    size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth();
-    parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
-        auto slot = slot_mapping.ptr<int32_t>(b)[m];
-        if (slot < 0)
-            return;
-        auto block_number = slot / block_size;
-        auto block_offset = slot % block_size;
-        // The layout for per token per head:
-        // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
-        // feature(u8,idx_S)|
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
-             src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) {
-            auto p_k = reinterpret_cast<float*>(
-                k_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                  h,
-                                                                                  block_offset,
-                                                                                  dst_offset));
-            quant_u8(k_src.ptr<T>(b, h, m, src_offset),
-                     k_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                       h,
-                                                                                       block_offset,
-                                                                                       dst_offset) +
-                         sizeof(float) + sizeof(float),
-                     key_group_size,
-                     p_k[0],
-                     p_k[1]);
-        }
-
+            quantize<T, KEY_DST_PREC>(
+                k_src.ptr<T>(b, h, m, src_offset),
+                k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
+                                                                                      h,
+                                                                                      block_offset,
+                                                                                      dst_offset) +
+                    sizeof(float) + sizeof(float),
+                key_group_size,
+                p_k);
+        }
         for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
              src_offset += value_group_size,
@@ -499,62 +457,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                     dst_offset);
             auto p_v = reinterpret_cast<float*>(v_base);
             uint8_t* v_ptr = v_base + sizeof(float) * 2;
-            quant_u4(v_src.ptr<T>(b, h, m, src_offset), v_ptr, value_group_size, p_v[0], p_v[1]);
-        }
-    });
-}
-
-template <typename T,
-          ov::element::Type_t DST_PREC,
-          typename std::enable_if<DST_PREC == ov::element::i4, bool>::type = true>
-static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
-                                const ov::intel_cpu::PlainTensor& v_src,
-                                const ov::intel_cpu::PlainTensor& k_dst,
-                                const ov::intel_cpu::PlainTensor& v_dst,
-                                const ov::intel_cpu::PlainTensor& slot_mapping,
-                                const size_t key_group_size,
-                                const size_t value_group_size) {
-    size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3];
-    size_t block_size = k_dst.m_dims[2];
-    size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth();
-    parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
-        auto slot = slot_mapping.ptr<int32_t>(b)[m];
-        if (slot < 0)
-            return;
-        auto block_number = slot / block_size;
-        auto block_offset = slot % block_size;
-        // The layout for per token per head:
-        // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
-        // feature(u8,idx_S)|
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
-             src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) {
-            auto p_k = reinterpret_cast<float*>(
-                k_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                  h,
-                                                                                  block_offset,
-                                                                                  dst_offset));
-            quant_u8(k_src.ptr<T>(b, h, m, src_offset),
-                     k_dst.ptr<typename ov::element_type_traits<DST_PREC>::value_type>(block_number,
-                                                                                       h,
-                                                                                       block_offset,
-                                                                                       dst_offset) +
-                         sizeof(float) + sizeof(float),
-                     key_group_size,
-                     p_k[0],
-                     p_k[1]);
-        }
-
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
-             src_offset += value_group_size, dst_offset += value_group_size / sub_byte_multiplier + sizeof(float)) {
-            uint8_t* v_base = reinterpret_cast<uint8_t*>(
-                v_dst.m_ptr.get() +
-                (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) /
-                    sub_byte_multiplier +
-                dst_offset);
-            auto p_v = reinterpret_cast<float*>(v_base);
-            uint8_t* v_ptr = v_base + sizeof(float);
-            quant_s4(v_src.ptr<T>(b, h, m, src_offset), v_ptr, value_group_size, p_v[0]);
+            quantize<T, VALUE_DST_PREC>(v_src.ptr<T>(b, h, m, src_offset), v_ptr, value_group_size, p_v);
         }
     });
 }
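The dst_offset strides in the unified loop follow from the per-group layout: an f32 scale and f32 zero-point, then the packed elements, with sub_byte_multiplier = 8 / bitwidth packing two 4-bit values per byte. A back-of-envelope helper, illustrative only and not part of the patch:

    #include <cstddef>
    #include <cstdio>

    // Bytes one quantization group occupies in the quantized cache:
    // |scale(f32)|zeropoint(f32)|packed elements|
    std::size_t group_bytes(std::size_t group_size, std::size_t bitwidth) {
        std::size_t sub_byte_multiplier = 8 / bitwidth;  // 1 for u8, 2 for u4
        return 2 * sizeof(float) + group_size / sub_byte_multiplier;
    }

    int main() {
        // group_size = 32: u8 -> 8 + 32 = 40 bytes, u4 -> 8 + 16 = 24 bytes.
        std::printf("u8: %zu, u4: %zu\n", group_bytes(32, 8), group_bytes(32, 4));
    }

The removed i4 path stored only the scale (sizeof(float)); the unified layout reserves both floats for every destination precision.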
diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp
index 209183f0060ad3..54aa80e9dff7c0 100644
--- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp
@@ -211,8 +211,13 @@ bool PagedAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) {
     try {
         auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE);
         auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE);
-        if (one_of(vCachePrecision, ov::element::i4, ov::element::u4, ov::element::u8)) {
-            if (kCachePrecision != ov::element::u8) {
+        if (one_of(vCachePrecision,
+                   ov::element::u4,
+                   ov::element::u8,
+                   ov::element::f32,
+                   ov::element::f16,
+                   ov::element::bf16)) {
+            if (!one_of(kCachePrecision, ov::element::u8, ov::element::f16, ov::element::f32, ov::element::bf16)) {
                 errorMessage = "PageAttn key value cache compression doesn't support key cache prec " +
                                kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string();
                 return false;
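The accepted combinations boil down to: value cache in {u4, u8, f32, f16, bf16} paired with a key cache in {u8, f16, f32, bf16}; i4 value caches are no longer allowed. Restated as a hypothetical standalone helper, not the plugin's actual code:

    #include <openvino/core/type/element_type.hpp>

    #include <algorithm>
    #include <initializer_list>

    static bool cache_precisions_supported(ov::element::Type key, ov::element::Type value) {
        auto in = [](ov::element::Type p, std::initializer_list<ov::element::Type> set) {
            return std::find(set.begin(), set.end(), p) != set.end();
        };
        // Any supported value-cache precision requires a key cache of u8/f16/f32/bf16.
        return in(value, {ov::element::u4, ov::element::u8, ov::element::f32, ov::element::f16, ov::element::bf16}) &&
               in(key, {ov::element::u8, ov::element::f16, ov::element::f32, ov::element::bf16});
    }
    // cache_precisions_supported(ov::element::u8, ov::element::u4) -> true
    // cache_precisions_supported(ov::element::u4, ov::element::u4) -> false: u4 keys are rejected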
diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp
index 59fd31cdb34303..016648a7e1026f 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp
@@ -187,6 +187,36 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckKVCachePrecision) {
     ASSERT_EQ(kv_cache_precision_value, ov::element::f32);
 }
 
+TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) {
+    ov::Core core;
+
+    core.set_property(deviceName, ov::hint::key_cache_precision(ov::element::f16));
+    core.set_property(deviceName, ov::hint::value_cache_precision(ov::element::u4));
+    ov::CompiledModel compiledModel = core.compile_model(model, deviceName);
+
+    auto key_cache_precision_value = ov::element::undefined;
+    auto value_cache_precision_value = ov::element::undefined;
+    OV_ASSERT_NO_THROW(key_cache_precision_value = compiledModel.get_property(ov::hint::key_cache_precision));
+    OV_ASSERT_NO_THROW(value_cache_precision_value = compiledModel.get_property(ov::hint::value_cache_precision));
+    ASSERT_EQ(key_cache_precision_value, ov::element::f16);
+    ASSERT_EQ(value_cache_precision_value, ov::element::u4);
+}
+
+TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCacheGroupSize) {
+    ov::Core core;
+
+    core.set_property(deviceName, ov::hint::key_cache_group_size(32));
+    core.set_property(deviceName, ov::hint::value_cache_group_size(16));
+    ov::CompiledModel compiledModel = core.compile_model(model, deviceName);
+
+    auto key_cache_group_size_value = 0;
+    auto value_cache_group_size_value = 0;
+    OV_ASSERT_NO_THROW(key_cache_group_size_value = compiledModel.get_property(ov::hint::key_cache_group_size));
+    OV_ASSERT_NO_THROW(value_cache_group_size_value = compiledModel.get_property(ov::hint::value_cache_group_size));
+    ASSERT_EQ(key_cache_group_size_value, 32);
+    ASSERT_EQ(value_cache_group_size_value, 16);
+}
+
 TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeDynamicQuantizationGroupSize) {
     ov::Core core;