From dea13d3d8f42937ea744e2b05ee6876d8dc6d95e Mon Sep 17 00:00:00 2001 From: dmitrygo Date: Wed, 8 Jan 2025 15:48:52 +0400 Subject: [PATCH] Revert "[CPU]PageAttn with 4bit-quantization (#27992)" This reverts commit b319014698cfed0112e4685dceb077dbcf5ed5ad. --- .../openvino/runtime/properties/__init__.py | 4 - .../pyopenvino/core/properties/properties.cpp | 4 - .../tests/test_runtime/test_properties.py | 12 - .../include/openvino/runtime/properties.hpp | 24 - src/plugins/intel_cpu/src/compiled_model.cpp | 12 - src/plugins/intel_cpu/src/config.cpp | 86 +- src/plugins/intel_cpu/src/config.h | 10 - .../nodes/kernels/scaled_attn/attn_quant.cpp | 221 +--- .../nodes/kernels/scaled_attn/attn_quant.hpp | 4 +- .../kernels/scaled_attn/attn_quant_kernel.hpp | 90 +- .../nodes/kernels/scaled_attn/executor_pa.cpp | 1115 +++++------------ .../nodes/kernels/scaled_attn/executor_pa.hpp | 6 +- .../intel_cpu/src/nodes/paged_attn.cpp | 31 +- .../intel_cpu/src/nodes/scaled_attn.cpp | 44 +- src/plugins/intel_cpu/src/nodes/scaled_attn.h | 7 +- src/plugins/intel_cpu/src/plugin.cpp | 12 - .../ov_executable_network/properties.cpp | 138 +- .../custom/behavior/ov_plugin/properties.cpp | 4 - 18 files changed, 411 insertions(+), 1413 deletions(-) diff --git a/src/bindings/python/src/openvino/runtime/properties/__init__.py b/src/bindings/python/src/openvino/runtime/properties/__init__.py index 76cfefcf05d8f8..511c019be8d969 100644 --- a/src/bindings/python/src/openvino/runtime/properties/__init__.py +++ b/src/bindings/python/src/openvino/runtime/properties/__init__.py @@ -28,10 +28,6 @@ from openvino._pyopenvino.properties import loaded_from_cache from openvino._pyopenvino.properties import cache_encryption_callbacks from openvino._pyopenvino.properties import weights_path -from openvino._pyopenvino.properties import key_cache_precision -from openvino._pyopenvino.properties import value_cache_precision -from openvino._pyopenvino.properties import key_cache_group_size -from openvino._pyopenvino.properties import value_cache_group_size # Submodules from openvino.runtime.properties import hint diff --git a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp index 2e3a2c8d1c9be8..d0e3ddb21644e7 100644 --- a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp +++ b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp @@ -34,10 +34,6 @@ void regmodule_properties(py::module m) { wrap_property_RW(m_properties, ov::force_tbb_terminate, "force_tbb_terminate"); wrap_property_RW(m_properties, ov::enable_mmap, "enable_mmap"); wrap_property_RW(m_properties, ov::weights_path, "weights_path"); - wrap_property_RW(m_properties, ov::key_cache_precision, "key_cache_precision"); - wrap_property_RW(m_properties, ov::value_cache_precision, "value_cache_precision"); - wrap_property_RW(m_properties, ov::key_cache_group_size, "key_cache_group_size"); - wrap_property_RW(m_properties, ov::value_cache_group_size, "value_cache_group_size"); wrap_property_RO(m_properties, ov::supported_properties, "supported_properties"); wrap_property_RO(m_properties, ov::available_devices, "available_devices"); diff --git a/src/bindings/python/tests/test_runtime/test_properties.py b/src/bindings/python/tests/test_runtime/test_properties.py index 15e2d86ead4653..61fb7442987418 100644 --- a/src/bindings/python/tests/test_runtime/test_properties.py +++ b/src/bindings/python/tests/test_runtime/test_properties.py @@ -257,18 +257,6 @@ def test_properties_ro(ov_property_ro, expected_value): "WEIGHTS_PATH", (("./model.bin", "./model.bin"),), ), - ( - props.key_cache_group_size, - "KEY_CACHE_GROUP_SIZE", - ((64, 64),), - ), - ( - props.value_cache_group_size, - "VALUE_CACHE_GROUP_SIZE", - ((64, 64),), - ), - (props.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)), - (props.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)), (hints.inference_precision, "INFERENCE_PRECISION_HINT", ((Type.f32, Type.f32),)), ( hints.model_priority, diff --git a/src/inference/include/openvino/runtime/properties.hpp b/src/inference/include/openvino/runtime/properties.hpp index 5dd3184599260e..28538f0f60e22e 100644 --- a/src/inference/include/openvino/runtime/properties.hpp +++ b/src/inference/include/openvino/runtime/properties.hpp @@ -1301,28 +1301,4 @@ static constexpr Property, PropertyMutability::RO> exec * @note This property is used for weightless caching. Only used when ov::CacheMode Property is set to "OPTIMIZE_SIZE". */ static constexpr Property weights_path{"WEIGHTS_PATH"}; - -/** - * @brief The precision of key cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property key_cache_precision{"KEY_CACHE_PRECISION"}; - -/** - * @brief The precision of value cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property value_cache_precision{"VALUE_CACHE_PRECISION"}; - -/** - * @brief The group_size of key cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property key_cache_group_size{"KEY_CACHE_GROUP_SIZE"}; - -/** - * @brief The group_size of value cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"}; } // namespace ov diff --git a/src/plugins/intel_cpu/src/compiled_model.cpp b/src/plugins/intel_cpu/src/compiled_model.cpp index 2fb0cba78d19ce..14a9b6e41c516f 100644 --- a/src/plugins/intel_cpu/src/compiled_model.cpp +++ b/src/plugins/intel_cpu/src/compiled_model.cpp @@ -256,10 +256,6 @@ ov::Any CompiledModel::get_property(const std::string& name) const { RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), - RO_property(ov::key_cache_precision.name()), - RO_property(ov::value_cache_precision.name()), - RO_property(ov::key_cache_group_size.name()), - RO_property(ov::value_cache_group_size.name()), }; return ro_properties; @@ -317,14 +313,6 @@ ov::Any CompiledModel::get_property(const std::string& name) const { return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision); - } else if (name == ov::key_cache_precision) { - return decltype(ov::key_cache_precision)::value_type(config.keyCachePrecision); - } else if (name == ov::value_cache_precision) { - return decltype(ov::value_cache_precision)::value_type(config.valueCachePrecision); - } else if (name == ov::key_cache_group_size) { - return decltype(ov::key_cache_group_size)::value_type(config.keyCacheGroupSize); - } else if (name == ov::value_cache_group_size) { - return decltype(ov::value_cache_group_size)::value_type(config.valueCacheGroupSize); } OPENVINO_THROW("Unsupported property: ", name); } diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 2d2df49d1876ad..9937c3fe1e82fa 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -309,60 +309,6 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { ov::hint::kv_cache_precision.name(), ". Supported values: u8, bf16, f16, f32"); } - } else if (key == ov::key_cache_precision.name()) { - try { - keyCachePrecisionSetExplicitly = true; - auto const prec = val.as(); - if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { - keyCachePrecision = prec; - } else { - OPENVINO_THROW("keyCachePrecision doesn't support value ", prec); - } - } catch (ov::Exception&) { - OPENVINO_THROW("Wrong value ", - val.as(), - " for property key ", - ov::key_cache_precision.name(), - ". Supported values: u8, bf16, f16, f32"); - } - } else if (key == ov::value_cache_precision.name()) { - try { - valueCachePrecisionSetExplicitly = true; - auto const prec = val.as(); - if (one_of(prec, - ov::element::f32, - ov::element::f16, - ov::element::bf16, - ov::element::u8, - ov::element::u4)) { - valueCachePrecision = prec; - } else { - OPENVINO_THROW("valueCachePrecision doesn't support value ", prec); - } - } catch (ov::Exception&) { - OPENVINO_THROW("Wrong value ", - val.as(), - " for property key ", - ov::value_cache_precision.name(), - ". Supported values: u4, u8, bf16, f16, f32"); - } - } else if (key == ov::key_cache_group_size.name() || key == ov::value_cache_group_size.name()) { - try { - auto const groupSize = val.as(); - if (key == ov::key_cache_group_size.name()) { - keyCacheGroupSizeSetExplicitly = true; - keyCacheGroupSize = groupSize; - } else { - valueCacheGroupSizeSetExplicitly = true; - valueCacheGroupSize = groupSize; - } - } catch (ov::Exception&) { - OPENVINO_THROW("Wrong value ", - val.as(), - " for property key ", - key, - ". Expected only unsinged integer numbers"); - } } else if (key == ov::cache_encryption_callbacks.name()) { try { const auto& encryption_callbacks = val.as(); @@ -398,13 +344,6 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { aclFastMath = true; } #endif - // key/value cache precision has higher priority, if not defined use kvCachePrecision - if (!keyCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) { - keyCachePrecision = kvCachePrecision; - } - if (!valueCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) { - valueCachePrecision = kvCachePrecision; - } // disable dynamic quantization and kv quantization for best accuracy if (executionMode == ov::hint::ExecutionMode::ACCURACY) { if (!fcDynamicQuantizationGroupSizeSetExplicitly) { @@ -413,12 +352,6 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { if (!kvCachePrecisionSetExplicitly) { kvCachePrecision = ov::element::f32; } - if (!keyCachePrecisionSetExplicitly) { - keyCachePrecision = ov::element::f32; - } - if (!valueCachePrecisionSetExplicitly) { - valueCachePrecision = ov::element::f32; - } } if (!prop.empty()) @@ -465,7 +398,7 @@ void Config::applyRtInfo(const std::shared_ptr& model) { // if user sets explicitly, it will be higher priority than rt_info if (!kvCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) { - this->kvCachePrecision = this->keyCachePrecision = this->valueCachePrecision = + this->kvCachePrecision = model->get_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()}); } if (!fcDynamicQuantizationGroupSizeSetExplicitly && @@ -473,23 +406,6 @@ void Config::applyRtInfo(const std::shared_ptr& model) { this->fcDynamicQuantizationGroupSize = model->get_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()}); } - if (!keyCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_precision.name()})) { - this->keyCachePrecision = - model->get_rt_info({"runtime_options", ov::key_cache_precision.name()}); - } - if (!valueCachePrecisionSetExplicitly && - model->has_rt_info({"runtime_options", ov::value_cache_precision.name()})) { - this->valueCachePrecision = - model->get_rt_info({"runtime_options", ov::value_cache_precision.name()}); - } - if (!keyCacheGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_group_size.name()})) { - this->keyCacheGroupSize = model->get_rt_info({"runtime_options", ov::key_cache_group_size.name()}); - } - if (!valueCacheGroupSizeSetExplicitly && - model->has_rt_info({"runtime_options", ov::value_cache_group_size.name()})) { - this->valueCacheGroupSize = - model->get_rt_info({"runtime_options", ov::value_cache_group_size.name()}); - } } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h index 8dad1853855c56..5a347b1fa30c94 100644 --- a/src/plugins/intel_cpu/src/config.h +++ b/src/plugins/intel_cpu/src/config.h @@ -48,27 +48,17 @@ struct Config { uint64_t fcDynamicQuantizationGroupSize = 32; bool fcDynamicQuantizationGroupSizeSetExplicitly = false; bool kvCachePrecisionSetExplicitly = false; - bool keyCachePrecisionSetExplicitly = false; - bool valueCachePrecisionSetExplicitly = false; - bool keyCacheGroupSizeSetExplicitly = false; - bool valueCacheGroupSizeSetExplicitly = false; #if defined(OV_CPU_WITH_ACL) bool aclFastMath = false; #endif #if defined(OPENVINO_ARCH_X86_64) ov::element::Type kvCachePrecision = ov::element::u8; - ov::element::Type keyCachePrecision = ov::element::u8; - ov::element::Type valueCachePrecision = ov::element::u8; size_t rtCacheCapacity = 5000ul; #else ov::element::Type kvCachePrecision = ov::element::f16; - ov::element::Type keyCachePrecision = ov::element::f16; - ov::element::Type valueCachePrecision = ov::element::f16; // TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives size_t rtCacheCapacity = 0ul; #endif - size_t keyCacheGroupSize = 0ul; - size_t valueCacheGroupSize = 0ul; ov::threading::IStreamsExecutor::Config streamExecutorConfig; int streams = 1; bool streamsChanged = false; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 26282a70fcb512..095180d659142e 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -27,10 +27,10 @@ namespace XARCH { using namespace ov; template -static void find_minmax(const T* src, size_t n, float& min, float& max) { - max = -FLT_MAX; - min = FLT_MAX; +static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { size_t i = 0; + float max = -FLT_MAX; + float min = FLT_MAX; #if defined(HAVE_AVX512F) auto v0_max = _mm512_set1_ps(-FLT_MAX); auto v0_min = _mm512_set1_ps(FLT_MAX); @@ -131,18 +131,12 @@ static void find_minmax(const T* src, size_t n, float& min, float& max) { max = std::max(max, tmp); min = std::min(min, tmp); } -} - -template -static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { - size_t i = 0; - float max = -FLT_MAX; - float min = FLT_MAX; - find_minmax(src, n, min, max); scale = (max - min) / 255; if (scale == 0) scale = 0.0001f; zp = -min / scale; + + i = 0; #if defined(HAVE_AVX512F) auto v_scale = _mm512_set1_ps(1 / scale); auto v_zp = _mm512_set1_ps(zp); @@ -176,116 +170,6 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& } } -template -static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) { - size_t i = 0; - float max = -FLT_MAX; - float min = FLT_MAX; - find_minmax(src, n, min, max); - auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 0 : 4; - return dst | (uint8_t)(val << shift); - }; - auto dst_ptr = reinterpret_cast(dst); - scale = (max - min) / ((1 << 4) - 1); - if (scale == 0) - scale = 0.0001f; - zp = -min / scale; -#if defined(HAVE_AVX512F) - auto v_scale = _mm512_set1_ps(1 / scale); - auto v_zp = _mm512_set1_ps(zp); - auto v_zero = _mm512_setzero_epi32(); - auto v_upper = _mm512_set1_epi32(15); - for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { - auto v0 = mm512_uni_loadu_ps(src + i); - auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); - v0 = _mm512_fmadd_ps(v0, v_scale, v_zp); - v1 = _mm512_fmadd_ps(v1, v_scale, v_zp); - auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - v0_i32 = _mm512_max_epi32(v0_i32, v_zero); - v1_i32 = _mm512_max_epi32(v1_i32, v_zero); - v0_i32 = _mm512_min_epi32(v0_i32, v_upper); - v1_i32 = _mm512_min_epi32(v1_i32, v_upper); - __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); - __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); - auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); - auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); - first_half = _mm512_slli_epi32(first_half, 4); - auto mask = _mm512_set1_epi32(0x0F); - second_half = _mm512_and_epi32(second_half, mask); - auto combined = _mm512_or_epi32(first_half, second_half); - _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); - } -#endif -#if defined(HAVE_AVX2) - auto v256_zero = _mm256_set1_epi32(0); - auto v256_upper = _mm256_set1_epi32(15); - auto v256_scale = _mm256_set1_ps(1 / scale); - auto v256_zp = _mm256_set1_ps(zp); - for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { - auto v0 = mm256_uni_loadu_ps(src + i); - auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2); - v0 = _mm256_fmadd_ps(v0, v256_scale, v256_zp); - v1 = _mm256_fmadd_ps(v1, v256_scale, v256_zp); - v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); - v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); - - auto v0_i32 = _mm256_cvtps_epi32(v0); - auto v1_i32 = _mm256_cvtps_epi32(v1); - v0_i32 = _mm256_max_epi32(v0_i32, v256_zero); - v1_i32 = _mm256_max_epi32(v1_i32, v256_zero); - v0_i32 = _mm256_min_epi32(v0_i32, v256_upper); - v1_i32 = _mm256_min_epi32(v1_i32, v256_upper); - auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1); - v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1); - // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 - // _mm256_permutevar8x32_epi32 - // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 - // _mm256_permute2x128_si256 - // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 - // shift + mask + or - // [0,1],[2,3], ..., [12,13], [14,15] - auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20); - auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31); - first_half = _mm256_slli_epi32(first_half, 4); - auto mask = _mm256_set1_epi32(0x0F); - second_half = _mm256_and_si256(second_half, mask); - auto combined = _mm256_or_si256(first_half, second_half); - - auto high4 = _mm256_extractf128_si256(combined, 1); - auto low4 = _mm256_castsi256_si128(combined); - // ignore sign bit for u4 case - auto packed = _mm_packus_epi32(low4, high4); - packed = _mm_packus_epi16(packed, packed); - _mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr + i / 2), packed); - } -#endif - for (; i < n; i++) { - float tmp = src[i]; -#define MIN(a, b) ((a) < (b) ? (a) : (b)) - uint8_t src_val = MIN(15, (uint8_t)(std::round(tmp / scale + zp))); - uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; - dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); - dst_ptr[i / 2] = dst_val; - } -} - -template ::type = true> -static void quantize(const T* src, uint8_t* dst, size_t n, float* scale_zp) { - quant_u8(src, dst, n, *scale_zp, *(scale_zp + 1)); -} - -template ::type = true> -static void quantize(const T* src, void* dst, size_t n, float* scale_zp) { - quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1)); -} - template static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, @@ -303,55 +187,36 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } -template +template static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping, - const size_t key_group_size, - const size_t value_group_size) { + const ov::intel_cpu::PlainTensor& slot_mapping) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; - size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; if (slot < 0) return; auto block_number = slot / block_size; auto block_offset = slot % block_size; + + auto p_k = reinterpret_cast(k_dst.ptr(block_number, h, block_offset)); + auto p_v = reinterpret_cast(v_dst.ptr(block_number, h, block_offset)); // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| - for (size_t src_offset = 0, dst_offset = 0; src_offset < S; - src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) { - auto p_k = reinterpret_cast( - k_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset)); - quantize( - k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset) + - sizeof(float) + sizeof(float), - key_group_size, - p_k); - } - - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += value_group_size, - dst_offset += value_group_size / sub_byte_multiplier + sizeof(float) + sizeof(float)) { - uint8_t* v_base = reinterpret_cast( - v_dst.m_ptr.get() + - (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / - sub_byte_multiplier + - dst_offset); - auto p_v = reinterpret_cast(v_base); - uint8_t* v_ptr = v_base + sizeof(float) * 2; - quantize(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v); - } + quant_u8(k_src.ptr(b, h, m), + k_dst.ptr(block_number, h, block_offset) + sizeof(float) + sizeof(float), + S, + p_k[0], + p_k[1]); + quant_u8(v_src.ptr(b, h, m), + v_dst.ptr(block_number, h, block_offset) + sizeof(float) + sizeof(float), + SV, + p_v[0], + p_v[1]); }); } @@ -380,48 +245,20 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping, - const size_t key_group_size, - const size_t value_group_size) { - using function_type = void (*)(const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const size_t, - const size_t); - static constexpr function_type funcs_fp32[] = { - paged_attn_quant_mt, - paged_attn_quant_mt, - }; - static constexpr function_type funcs_bf16[] = { - paged_attn_quant_mt, - paged_attn_quant_mt, - }; - static constexpr function_type funcs_f16[] = { - paged_attn_quant_mt, - paged_attn_quant_mt, - }; - if (k_dst.get_precision() != ov::element::u8) { + const ov::intel_cpu::PlainTensor& slot_mapping) { + if (k_src.get_precision() == ov::element::f32 && k_dst.get_precision() == ov::element::u8) { + paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); + } else if (k_src.get_precision() == ov::element::bf16 && k_dst.get_precision() == ov::element::u8) { + paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); + } else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) { + paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); + } else { OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv"); } - std::map dispatch_table = { - {ov::element::u8, 0}, - {ov::element::u4, 1}, - {ov::element::i4, 2}, - }; - size_t dispatch = dispatch_table[v_dst.get_precision()]; - if (k_src.get_precision() == ov::element::f32) { - funcs_fp32[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); - } else if (k_src.get_precision() == ov::element::bf16) { - funcs_bf16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); - } else if (k_src.get_precision() == ov::element::f16) { - funcs_f16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); - } } void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float& zp) { @@ -429,7 +266,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float } void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) { - attn_dequant_kernel(src, dst, n, scale, zp); + attn_dequant_u8_kernel(src, dst, n, scale, zp); } } // namespace XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp index 364e5775861ed2..2f39f74f5b3460 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp @@ -27,9 +27,7 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping, - const size_t key_group_size, - const size_t value_group_size); + const ov::intel_cpu::PlainTensor& slot_mapping); void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float& zp); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 761a136eda2997..759d0005103871 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -17,10 +17,8 @@ namespace Extensions { namespace Cpu { namespace XARCH { -template ::type = true> -void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { +template +void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { size_t i = 0; // loadu_si128/epi64 does not support const qualifier uint8_t* src_nc = const_cast(src); @@ -54,90 +52,6 @@ void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, f } } -template ::type = true> -void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { - // 2 4bit data form a byte - /* 0,1|2,3|4,5|6,7 - / \ - 0,2,4,6|1,3,5,7 - | - permute - | - 0,1,2,3,4,5,6,7 - */ - size_t i = 0; - uint8_t* src_nc = const_cast(src); -#if defined(HAVE_AVX512F) - auto v_scale = _mm512_set1_ps(scale); - auto v_zp_scale = _mm512_set1_ps(zp * scale); - for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); - auto v_i32 = _mm512_cvtepu8_epi32(data); - - auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); - auto v_f32_low_half = _mm512_cvtepi32_ps(v_512_low_half); - - auto mask = _mm512_set1_epi32(0x0F); - auto v_512_high_half = _mm512_and_si512(v_i32, mask); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); - // q * scale- zp * scale - v_f32_low_half = _mm512_fmsub_ps(v_f32_low_half, v_scale, v_zp_scale); - v_f32_high_half = _mm512_fmsub_ps(v_f32_high_half, v_scale, v_zp_scale); - __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); - __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); - __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); - __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); - mm512_uni_storeu_ps(dst + i, first_half); - mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); - } -#elif defined(HAVE_AVX2) - auto v256_zp = _mm256_set1_ps(zp); - auto v256_scale = _mm256_set1_ps(scale); - for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2)); - - auto v_i32 = _mm256_cvtepu8_epi32(data); - auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); - auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); - - auto mask = _mm256_set1_epi32(0x0F); - auto v_256_high_half = _mm256_and_si256(v_i32, mask); - auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); - // q - zp - v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp); - v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); - - v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); - v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); - - // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 - // _mm256_permute2f128_ps - // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 - // _mm256_permutevar8x32_ps - // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 - __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); - auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); - first_half = _mm256_permutevar8x32_ps(first_half, idx1); - __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); - second_half = _mm256_permutevar8x32_ps(second_half, idx1); - - mm256_uni_storeu_ps(dst + i, first_half); - mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half); - } -#endif - auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 0 : 4; - return (uint8_t)((val >> shift) & 0x000F); - }; - for (; i < n; ++i) { - float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); - tmp = (tmp - zp) * scale; - dst[i] = tmp; - } -} - } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 955e7687ef97b3..a74021d8ac0d05 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -72,22 +72,8 @@ void cvt_copy(TA* dst, TB* src, size_t n) { } } -size_t inline get_sub_byte_multiplier(ov::element::Type type) { - return one_of(type, ov::element::i4, ov::element::u4) ? 8 / type.bitwidth() : 1; -} - -template ::value || std::is_same::value || - std::is_same::value) && - (SRC_PREC != ov::element::u8 || SRC_PREC != ov::element::u4), - bool>::type = true> -static void attn_acc_value_block(float* out, - float* weight, - T* v, - const size_t S, - const size_t block_size, - const size_t group_size) { +template +static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -214,262 +200,117 @@ static void attn_acc_value_block(float* out, v += S; } } -template ::type = true> -static void attn_acc_value_block(float* out, - float* weight, - uint8_t* v, - const size_t S, - const size_t block_size, - const size_t group_size) { + +static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - size_t src_offset = 0; - size_t dst_offset = 0; - const size_t params_offset = sizeof(float) * 2; - const size_t src_stride = S / group_size * (group_size + params_offset); - # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - // process group by group - uint8_t* v_ptr = v + src_offset; - auto v_f0 = reinterpret_cast(v_ptr); - auto v_f1 = reinterpret_cast(v_ptr + src_stride); - auto v_f2 = reinterpret_cast(v_ptr + 2 * src_stride); - auto v_f3 = reinterpret_cast(v_ptr + 3 * src_stride); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); - auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); - auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - auto zp1 = _mm512_set1_ps(v_f1[1]); - auto zp2 = _mm512_set1_ps(v_f2[1]); - auto zp3 = _mm512_set1_ps(v_f3[1]); - uint8_t* v_data_ptr = v + src_offset + params_offset; - size_t i = 0; - for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + dst_offset + i); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), - zp0); - auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))), - zp1); - auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( - reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))), - zp2); - auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( - reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))), - zp3); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); - _mm512_storeu_ps(out + dst_offset + i, v_out); - } - for (; i < group_size; i++) { - out[i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; - out[i] += weight[1] * (v_data_ptr[i + src_stride] - v_f1[1]) * v_f1[0]; - out[i] += weight[2] * (v_data_ptr[i + 2 * src_stride] - v_f2[1]) * v_f2[0]; - out[i] += weight[3] * (v_data_ptr[i + 3 * src_stride] - v_f3[1]) * v_f3[0]; - } - dst_offset += group_size; - src_offset += group_size + params_offset; + auto v_f0 = reinterpret_cast(v); + auto v_f1 = reinterpret_cast(v + S + 8); + auto v_f2 = reinterpret_cast(v + 2 * (S + 8)); + auto v_f3 = reinterpret_cast(v + 3 * (S + 8)); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); + auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); + auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + auto zp1 = _mm512_set1_ps(v_f1[1]); + auto zp2 = _mm512_set1_ps(v_f2[1]); + auto zp3 = _mm512_set1_ps(v_f3[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + auto v0 = _mm512_sub_ps( + _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), + zp0); + auto v1 = _mm512_sub_ps( + _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + S + 8)))), + zp1); + auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 2 * (S + 8))))), + zp2); + auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 3 * (S + 8))))), + zp3); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); + + _mm512_storeu_ps(out + i, v_out); } + for (; i < S; i++) { + out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + out[i] += weight[1] * (v[i + S + 8] - v_f1[1]) * v_f1[0]; + out[i] += weight[2] * (v[i + 2 * (S + 8)] - v_f2[1]) * v_f2[0]; + out[i] += weight[3] * (v[i + 3 * (S + 8)] - v_f3[1]) * v_f3[0]; + } + v += 4 * (S + 8) - 8; weight += 4; - v += 4 * src_stride; } for (; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - uint8_t* v_ptr = v + src_offset; - uint8_t* v_data_ptr = v_ptr + params_offset; - auto v_f0 = reinterpret_cast(v_ptr); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - size_t i = 0; - for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps((out + dst_offset + i)); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), - zp0); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - - _mm512_storeu_ps((out + dst_offset + i), v_out); - } - for (; i < group_size; i++) { - out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; - } - dst_offset += group_size; - src_offset += group_size + params_offset; + auto v_f0 = reinterpret_cast(v); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + auto v0 = _mm512_sub_ps( + _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), + zp0); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + + _mm512_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; } - v += src_stride; + v += S; weight++; } return; # elif defined(HAVE_AVX2) size_t j = 0; for (; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - uint8_t* v_ptr = v + src_offset; - uint8_t* v_data_ptr = v_ptr + params_offset; - auto v_f0 = reinterpret_cast(v_ptr); - auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm256_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + dst_offset + i); - auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_data_ptr + i)))), - zp0); - v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); - - mm256_uni_storeu_ps(out + dst_offset + i, v_out); - } - for (; i < group_size; i++) { - out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; - } - dst_offset += group_size; - src_offset += group_size + params_offset; + auto v_f0 = reinterpret_cast(v); + auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm256_set1_ps(v_f0[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + i); + auto v0 = _mm256_sub_ps( + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i)))), + zp0); + v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); + + mm256_uni_storeu_ps(out + i, v_out); } - v += src_stride; + for (; i < S; i++) { + out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + } + v += S; weight++; } return; # endif for (size_t j = 0; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - auto v0 = reinterpret_cast(v + src_offset); - for (size_t i = 0; i < group_size; i++) { - out[dst_offset + i] += weight[j] * (v[i + src_offset + params_offset] - v0[1]) * v0[0]; - } - dst_offset += group_size; - src_offset += group_size + params_offset; - } - v += src_stride; - } -} - -template ::type = true> -static void attn_acc_value_block(float* out, - float* weight, - void* v, - const size_t S, - const size_t block_size, - const size_t group_size) { - size_t src_offset = 0; - size_t dst_offset = 0; - const size_t params_offset = sizeof(float) * 2; - uint8_t* v_ptr = reinterpret_cast(v); - auto sub_byte_multiplier = 8 / 4; - const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset); - auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 0 : 4; - - return (uint8_t)((val >> shift) & 0x000F); - }; - for (size_t j = 0; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - auto v0 = reinterpret_cast(v_ptr + src_offset); - size_t i = 0; -# if defined(HAVE_AVX512F) - auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); - auto v_zp = _mm512_set1_ps(v0[1]); - for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); - auto v_i32 = _mm512_cvtepu8_epi32(data); - - auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); - auto v_f32_low_half = _mm512_cvtepi32_ps(v_512_low_half); - - auto mask = _mm512_set1_epi32(0x0F); - auto v_512_high_half = _mm512_and_si512(v_i32, mask); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); - - // q - zp - v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); - v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); - - __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); - __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); - __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); - __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); - auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); - auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); - v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); - v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); - mm512_uni_storeu_ps(out + dst_offset + i, v_out0); - mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); - } -# elif defined(HAVE_AVX2) - auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); - auto v256_zp = _mm256_set1_ps(v0[1]); - for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); - - auto v_i32 = _mm256_cvtepu8_epi32(data); - auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); - auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); - - auto mask = _mm256_set1_epi32(0x0F); - auto v_256_high_half = _mm256_and_si256(v_i32, mask); - auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); - // q - zp - v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp); - v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); - - auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i); - auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2); - - __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); - auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); - first_half = _mm256_permutevar8x32_ps(first_half, idx1); - __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); - second_half = _mm256_permutevar8x32_ps(second_half, idx1); - - v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0); - v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1); - mm256_uni_storeu_ps(out + dst_offset + i, v_out0); - mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); - } -# endif - for (; i < group_size; i += 2) { - uint8_t data = v_ptr[i / 2 + src_offset + params_offset]; - float tmp0 = extract_half_byte(data, static_cast(i % 2)); - float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); - out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; - out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; - } - dst_offset += group_size; - src_offset += group_size / sub_byte_multiplier + params_offset; + auto v0 = reinterpret_cast(v); + v += 8; + for (size_t i = 0; i < S; i++) { + out[i] += weight[j] * (v[i] - v0[1]) * v0[0]; } - v_ptr += src_stride; + v += S; } } template -static void dot_product_block(TA* a, - TB* b, - float* c, - const size_t n, - const size_t block_size, - const size_t group_size) { +static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -581,235 +422,175 @@ static void dot_product_block(TA* a, } template -static void dot_product_block(TA* a, - uint8_t* b, - float* c, - const size_t n, - const size_t block_size, - const size_t group_size) { +static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - size_t src_offset = 0; - size_t dst_offset = 0; - const size_t params_offset = sizeof(float) * 2; - const size_t src_stride = n / group_size * (group_size + params_offset); # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - src_offset = 0; - dst_offset = 0; - float sum0 = 0.0f; - float sum1 = 0.0f; - float sum2 = 0.0f; - float sum3 = 0.0f; - while (dst_offset < n) { - auto vsum0 = _mm512_setzero_ps(); - auto vsum1 = _mm512_setzero_ps(); - auto vsum2 = _mm512_setzero_ps(); - auto vsum3 = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b + src_offset); - auto b1 = reinterpret_cast(b + src_offset + src_stride); - auto b2 = reinterpret_cast(b + src_offset + src_stride * 2); - auto b3 = reinterpret_cast(b + src_offset + src_stride * 3); - auto v_zp0 = _mm512_set1_ps(b0[1]); - auto v_zp1 = _mm512_set1_ps(b1[1]); - auto v_zp2 = _mm512_set1_ps(b2[1]); - auto v_zp3 = _mm512_set1_ps(b3[1]); - size_t i = 0; - uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + dst_offset + i); - auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), - v_zp0); - auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), - v_zp1); - auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( - reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), - v_zp2); - auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( - reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), - v_zp3); - - vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); - } - float group_sum0 = _mm512_reduce_add_ps(vsum0); - float group_sum1 = _mm512_reduce_add_ps(vsum1); - float group_sum2 = _mm512_reduce_add_ps(vsum2); - float group_sum3 = _mm512_reduce_add_ps(vsum3); - for (; i < group_size; i++) { - group_sum0 += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); - group_sum1 += a[i + dst_offset] * (b_data_ptr[i + src_stride] - b1[1]); - group_sum2 += a[i + dst_offset] * (b_data_ptr[i + 2 * src_stride] - b2[1]); - group_sum3 += a[i + dst_offset] * (b_data_ptr[i + 3 * src_stride] - b3[1]); - } - sum0 += group_sum0 * b0[0]; - sum1 += group_sum1 * b1[0]; - sum2 += group_sum2 * b2[0]; - sum3 += group_sum3 * b3[0]; - dst_offset += group_size; - src_offset += group_size + params_offset; + auto vsum0 = _mm512_setzero_ps(); + auto vsum1 = _mm512_setzero_ps(); + auto vsum2 = _mm512_setzero_ps(); + auto vsum3 = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto b1 = reinterpret_cast(b + n + 8); + auto b2 = reinterpret_cast(b + (n + 8) * 2); + auto b3 = reinterpret_cast(b + (n + 8) * 3); + auto v_zp0 = _mm512_set1_ps(b0[1]); + auto v_zp1 = _mm512_set1_ps(b1[1]); + auto v_zp2 = _mm512_set1_ps(b2[1]); + auto v_zp3 = _mm512_set1_ps(b3[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + auto vb0 = _mm512_sub_ps( + _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), + v_zp0); + auto vb1 = _mm512_sub_ps( + _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + n + 8)))), + v_zp1); + auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), + v_zp2); + auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), + v_zp3); + + vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); } - c[0] = sum0; - c[1] = sum1; - c[2] = sum2; - c[3] = sum3; + float sum0 = _mm512_reduce_add_ps(vsum0); + float sum1 = _mm512_reduce_add_ps(vsum1); + float sum2 = _mm512_reduce_add_ps(vsum2); + float sum3 = _mm512_reduce_add_ps(vsum3); + for (; i < n; i++) { + sum0 += a[i] * (b[i] - b0[1]); + sum1 += a[i] * (b[i + n + 8] - b1[1]); + sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); + sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); + } + c[0] = sum0 * b0[0]; + c[1] = sum1 * b1[0]; + c[2] = sum2 * b2[0]; + c[3] = sum3 * b3[0]; c += 4; - b += 4 * src_stride; + b += 4 * (n + 8) - 8; } for (; j < block_size; j++) { - src_offset = 0; - dst_offset = 0; - float sum = 0; - while (dst_offset < n) { - auto vsum = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b + src_offset); - auto v_zp = _mm512_set1_ps(b0[1]); - size_t i = 0; - uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + dst_offset + i); - auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), - v_zp); - vsum = _mm512_fmadd_ps(va, vb, vsum); - } - float group_sum = _mm512_reduce_add_ps(vsum); - for (; i < group_size; i++) { - group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); - } - sum += group_sum * b0[0]; - dst_offset += group_size; - src_offset += group_size + params_offset; + auto vsum = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto v_zp = _mm512_set1_ps(b0[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + auto vb = _mm512_sub_ps( + _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), + v_zp); + vsum = _mm512_fmadd_ps(va, vb, vsum); } - b += src_stride; - *c++ = sum; + float sum = _mm512_reduce_add_ps(vsum); + for (; i < n; i++) { + sum += a[i] * (b[i] - b0[1]); + } + b += n; + *c++ = sum * b0[0]; } return; # elif defined(HAVE_AVX2) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - src_offset = 0; - dst_offset = 0; - float sum0 = 0.0f; - float sum1 = 0.0f; - float sum2 = 0.0f; - float sum3 = 0.0f; - while (dst_offset < n) { - auto vsum0 = _mm256_setzero_ps(); - auto vsum1 = _mm256_setzero_ps(); - auto vsum2 = _mm256_setzero_ps(); - auto vsum3 = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b + src_offset); - auto b1 = reinterpret_cast(b + src_offset + src_stride); - auto b2 = reinterpret_cast(b + src_offset + src_stride * 2); - auto b3 = reinterpret_cast(b + src_offset + src_stride * 3); - auto v_zp0 = _mm256_set1_ps(b0[1]); - auto v_zp1 = _mm256_set1_ps(b1[1]); - auto v_zp2 = _mm256_set1_ps(b2[1]); - auto v_zp3 = _mm256_set1_ps(b3[1]); - size_t i = 0; - uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + dst_offset + i); - auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), - v_zp0); - auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), - v_zp1); - auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( - reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), - v_zp2); - auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( - reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), - v_zp3); - - vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); - } - hsum(vsum0); - hsum(vsum1); - hsum(vsum2); - hsum(vsum3); - float group_sum0 = _mm256_cvtss_f32(vsum0); - float group_sum1 = _mm256_cvtss_f32(vsum1); - float group_sum2 = _mm256_cvtss_f32(vsum2); - float group_sum3 = _mm256_cvtss_f32(vsum3); - for (; i < group_size; i++) { - group_sum0 += a[dst_offset + i] * (b[i] - b0[1]); - group_sum1 += a[dst_offset + i] * (b[i + src_stride] - b1[1]); - group_sum2 += a[dst_offset + i] * (b[i + 2 * src_stride] - b2[1]); - group_sum3 += a[dst_offset + i] * (b[i + 3 * src_stride] - b3[1]); - } - sum0 += group_sum0 * b0[0]; - sum1 += group_sum1 * b1[0]; - sum2 += group_sum2 * b2[0]; - sum3 += group_sum3 * b3[0]; - dst_offset += group_size; - src_offset += group_size + params_offset; + auto vsum0 = _mm256_setzero_ps(); + auto vsum1 = _mm256_setzero_ps(); + auto vsum2 = _mm256_setzero_ps(); + auto vsum3 = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto b1 = reinterpret_cast(b + n + 8); + auto b2 = reinterpret_cast(b + (n + 8) * 2); + auto b3 = reinterpret_cast(b + (n + 8) * 3); + auto v_zp0 = _mm256_set1_ps(b0[1]); + auto v_zp1 = _mm256_set1_ps(b1[1]); + auto v_zp2 = _mm256_set1_ps(b2[1]); + auto v_zp3 = _mm256_set1_ps(b3[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + auto vb0 = _mm256_sub_ps( + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), + v_zp0); + auto vb1 = _mm256_sub_ps( + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + n + 8)))), + v_zp1); + auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), + v_zp2); + auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), + v_zp3); + + vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); } - c[0] = sum0; - c[1] = sum1; - c[2] = sum2; - c[3] = sum3; + hsum(vsum0); + hsum(vsum1); + hsum(vsum2); + hsum(vsum3); + float sum0 = _mm256_cvtss_f32(vsum0); + float sum1 = _mm256_cvtss_f32(vsum1); + float sum2 = _mm256_cvtss_f32(vsum2); + float sum3 = _mm256_cvtss_f32(vsum3); + for (; i < n; i++) { + sum0 += a[i] * (b[i] - b0[1]); + sum1 += a[i] * (b[i + n + 8] - b1[1]); + sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); + sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); + } + c[0] = sum0 * b0[0]; + c[1] = sum1 * b1[0]; + c[2] = sum2 * b2[0]; + c[3] = sum3 * b3[0]; c += 4; - b += 4 * src_stride; + b += 4 * (n + 8) - 8; } for (; j < block_size; j++) { - src_offset = 0; - dst_offset = 0; - float sum = 0; - while (dst_offset < n) { - auto vsum = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b + src_offset); - auto v_zp = _mm256_set1_ps(b0[1]); - size_t i = 0; - uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + dst_offset + i); - auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), - v_zp); - vsum = _mm256_fmadd_ps(va, vb, vsum); - } - hsum(vsum); - float group_sum = _mm256_cvtss_f32(vsum); - for (; i < group_size; i++) { - group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); - } - sum += group_sum * b0[0]; - dst_offset += group_size; - src_offset += group_size + params_offset; + auto vsum = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto v_zp = _mm256_set1_ps(b0[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + auto vb = _mm256_sub_ps( + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), + v_zp); + vsum = _mm256_fmadd_ps(va, vb, vsum); } - b += src_stride; - *c++ = sum; + hsum(vsum); + float sum = _mm256_cvtss_f32(vsum); + for (; i < n; i++) { + sum += a[i] * (b[i] - b0[1]); + } + b += n; + *c++ = sum * b0[0]; } return; # endif for (size_t j = 0; j < block_size; j++) { float sum = 0; - dst_offset = 0; - src_offset = 0; - while (dst_offset < n) { - auto b0 = reinterpret_cast(b + src_offset); - float group_sum = 0.0f; - for (size_t i = 0; i < group_size; i++) { - group_sum += a[dst_offset + i] * (b[src_offset + params_offset + i] - b0[1]); - } - sum += group_sum * b0[0]; - dst_offset += group_size; - src_offset += group_size + params_offset; + auto b0 = reinterpret_cast(b); + b += 8; + for (size_t i = 0; i < n; i++) { + sum += a[i] * (b[i] - b0[1]); } - b += src_stride; - *c++ = sum; + b += n; + *c++ = sum * b0[0]; } } @@ -853,138 +634,73 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str } // N must be multiple of 16 -template ::type = true> -void transpose_16NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { +template +void transpose_16NxK(TDST* dst, TSRC* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { size_t k = 0; - auto* src_ptr = reinterpret_cast::value_type*>(src); for (; k + 16 <= K; k += 16) { for (size_t n = 0; n < N; n += 16) { - transpose_16x16_kernel(dst + n, src_ptr + n * src_stride, dst_stride, src_stride); + transpose_16x16_kernel(dst + n, src + n * src_stride, dst_stride, src_stride); } dst += 16 * dst_stride; - src_ptr += 16; + src += 16; } if (k < K) { for (size_t n = 0; n < N; n += 16) { - transpose_16xK_kernel(dst + n, src_ptr + n * src_stride, K - k, dst_stride, src_stride); + transpose_16xK_kernel(dst + n, src + n * src_stride, K - k, dst_stride, src_stride); } } } + # if defined(HAVE_AVX512F) template ::value), - bool>::type = true> -static void transpose_16NxK(T* dst, - T* src, - T* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { + typename = typename std:: + enable_if<(std::is_same::value || std::is_same::value), bool>::type> +static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // will treat as uint32_t transpose auto s = reinterpret_cast(src); auto d = reinterpret_cast(dst); - transpose_16NxK(d, - s, - reinterpret_cast(0), - N, - K >> 1, - dst_stride, - src_stride >> 1, - group_size); + transpose_16NxK(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1); } # endif -template ::type = true> -void transpose_16NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { +template +void transpose_16NxK(TDST* dst, uint8_t* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = reinterpret_cast::value_type*>(src); + auto s = src; auto t = tmp; - // if group_size not set, the whole row is used as a group for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_kernel(s + src_offset + sizeof(float) * 2, - t + dst_offset, - group_size, - f[0], - f[1]); - src_offset += group_size + sizeof(float) * 2; - dst_offset += group_size; - } - s += src_offset; + auto f = reinterpret_cast(s); + attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); + s += src_stride + 2 * sizeof(float); t += src_stride; } - transpose_16NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, 0); + transpose_16NxK(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } // dequant f16/u8 to float -template ::type = true> -static inline void dequant(T* dst, void* src, const size_t N, const size_t K, const size_t group_size) { +template +static inline void dequant(T* dst, T* src, size_t N, size_t K) { // never called OPENVINO_THROW("dequant: should not be called."); } -template ::type = true> -static inline void dequant(float* dst, ov::float16* src, const size_t N, const size_t K, const size_t group_size) { + +static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K) { cvt_copy(dst, src, K * N); } -template ::type = true> -void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { +template +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; - const size_t params_offset = sizeof(float) * 2; - const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); - for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_kernel(s + src_offset + params_offset, - dst + dst_offset, - group_size, - f[0], - f[1]); - src_offset += group_size / sub_byte_mulitplier + params_offset; - dst_offset += group_size; - } - s += src_offset; + auto f = reinterpret_cast(s); + attn_dequant_u8_kernel(s + 2 * sizeof(float), dst, K, f[0], f[1]); + s += K + 2 * sizeof(float); dst += K; } } @@ -1056,101 +772,54 @@ static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_strid } } -template ::value != ov::element::f32 && - (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), - bool>::type = true> -static void pack_32NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { - auto src_ptr = reinterpret_cast::value_type*>(src); +template ::value || std::is_same::value), bool>::type> +static void pack_32NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { for (size_t n = 0; n < N; n += 32) { size_t k = 0; for (; k + 32 <= K; k += 32) { - pack_32x32_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride); + pack_32x32_kernel(dst + k * 2, src + k, dst_stride, src_stride); } if (k + 16 <= K) { - pack_32x16_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride); + pack_32x16_kernel(dst + k * 2, src + k, dst_stride, src_stride); k += 16; } if (k < K) { - pack_32xK_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride, K - k); + pack_32xK_kernel(dst + k * 2, src + k, dst_stride, src_stride, K - k); } dst += 32 * dst_stride; - src_ptr += 32 * src_stride; + src += 32 * src_stride; } } -template ::value != ov::element::f32 && - (SRC_PREC == ov::element::u4 || SRC_PREC == ov::element::u8), - bool>::type = true> -static void pack_32NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { +template ::value || std::is_same::value), bool>::type> +static void pack_32NxK(T* dst, uint8_t* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = reinterpret_cast(src); + auto s = src; auto t = tmp; - // if group_size not set, the whole row is used as a group - const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_kernel(s + (src_offset + sizeof(float) * 2), - t + dst_offset, - group_size, - f[0], - f[1]); - src_offset += group_size / sub_byte_mulitplier + sizeof(float) * 2; - dst_offset += group_size; - } - s += src_offset; + auto f = reinterpret_cast(s); + attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); + s += src_stride + 2 * sizeof(float); t += src_stride; } - pack_32NxK::value>(dst, - tmp, - reinterpret_cast(0), - N, - K, - dst_stride, - src_stride, - group_size); + pack_32NxK(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } # endif -template ::value == ov::element::f32, bool>::type = true> -static void pack_32NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { +template +static void pack_32NxK(float* dst, T* src, float* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // never called OPENVINO_THROW("pack_32NxK: should not be called."); } -template +template struct MHAHelper { // initialize once size_t _H; @@ -1162,8 +831,6 @@ struct MHAHelper { size_t _nthr; size_t _sliding_window; float _d_scale; - size_t _key_group_size = 0; - size_t _value_group_size = 0; PlainTensor _weight; // [nthr, H, 32, rnd_up(kv_len, block_size)], shared by first and second loop along bh PlainTensor _output; // [nthr, 32, H, S], shared by first and second loop along bh @@ -1193,12 +860,6 @@ struct MHAHelper { _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); } - explicit MHAHelper(size_t key_group_size, size_t value_group_size) - : _key_group_size(key_group_size), - _value_group_size(value_group_size) { - _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); - } - void init(size_t H, size_t S, size_t SV, @@ -1286,11 +947,11 @@ struct MHAHelper { if ((S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && precision_of::value == ov::element::bf16 && - precision_of::value == ov::element::bf16 && VALUE_PREC == ov::element::bf16) { + precision_of::value == ov::element::bf16) { _fastpath_valid_prec = ov::element::bf16; } else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_fp16) && precision_of::value == ov::element::f16 && - precision_of::value == ov::element::f16 && VALUE_PREC == ov::element::f16) { + precision_of::value == ov::element::f16) { _fastpath_valid_prec = ov::element::f16; } } @@ -1359,7 +1020,7 @@ struct MHAHelper { auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == VALUE_PREC; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size); for (size_t h = hq_beg; h < hq_end; h++) { auto* q_ptr = query.ptr(h, q_start, 0); @@ -1503,7 +1164,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { (*_gemv)(query.ptr(h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight.ptr(ithr, h, pq) + pk); } } @@ -1515,11 +1176,10 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight.ptr(ithr, h, pq) + pk, _S, - std::min(_block_size, cur_kv_len - pk), - _key_group_size); + std::min(_block_size, cur_kv_len - pk)); } } } @@ -1557,20 +1217,14 @@ struct MHAHelper { memset(_output.ptr(ithr), 0, q_len * _H * _SV * sizeof(float)); for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) { auto block_number = block_table[i]; + auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); - size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * - present_value.get_precision().size() / sub_byte_multiplier; - auto* v_ptr = reinterpret_cast::value_type*>( - present_value.m_ptr.get() + v_stride); - attn_acc_value_block::value_type, VALUE_PREC>( - _output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); + attn_acc_value_block(_output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v, + _SV, + std::min(_block_size, cur_kv_len - pv)); } } } @@ -1647,7 +1301,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { (*_gemv)(query.ptr(b, h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk); } } @@ -1656,11 +1310,10 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(b, h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk, _S, - std::min(_block_size, context_len - pk), - _key_group_size); + std::min(_block_size, context_len - pk)); } } } @@ -1727,21 +1380,14 @@ struct MHAHelper { // kv_len must be valid if (pv < context_len) { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pv_in_blocks]; + auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); - size_t v_stride = - (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * - present_value.get_precision().size() / sub_byte_multiplier; - auto* v_ptr = reinterpret_cast::value_type*>( - present_value.m_ptr.get() + v_stride); - attn_acc_value_block::value_type, VALUE_PREC>( - _output_bhl.ptr(ithr, b, pq, h), - _weight_bhl.ptr(b, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, context_len - pv), - _value_group_size); + attn_acc_value_block(_output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v, + _SV, + std::min(_block_size, context_len - pv)); } } } @@ -1762,9 +1408,9 @@ struct MHAHelper { } }; -template +template struct MHA { - MHAHelper& _helper; + MHAHelper& _helper; struct AttnWorkItem { int32_t batch_in_reorder; // which batch in reorder buffer will be used int32_t batch_in_seq; // batch idx in sequence @@ -1864,7 +1510,7 @@ struct MHA { WorkItems _workitems; - MHA(MHAHelper& helper) : _helper(helper) {} + MHA(MHAHelper& helper) : _helper(helper) {} // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, @@ -1881,7 +1527,7 @@ struct MHA { auto Hk = v_cache.m_dims[1]; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == VALUE_PREC; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; auto attn_work_count = _workitems.attn_work_size(); auto reorder_work_count = _workitems.reorder_work_size(); @@ -1901,45 +1547,30 @@ struct MHA { return; auto ithr = parallel_get_thread_num(); - auto* k_ptr = k_cache.ptr(block_number, hk); - - transpose_16NxK::value>( - _helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - k_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._S, - _helper._block_size, - _helper._S, - _helper._key_group_size); - + auto* k_ptr = k_cache.ptr(block_number, hk); + auto* v_ptr = v_cache.ptr(block_number, hk); + transpose_16NxK(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + k_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, + _helper._block_size, + _helper._S); if (q_is_xf16) { - auto sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); - size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) * - v_cache.get_precision().size() / sub_byte_multiplier; - auto* v_ptr = v_cache.m_ptr.get() + v_stride; - pack_32NxK( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, - _helper._value_group_size); + pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV); } else { // need to decompress if (!q_cache_is_same) { - auto sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); - size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) * - v_cache.get_precision().size() / sub_byte_multiplier; - auto* v_ptr = v_cache.m_ptr.get() + v_stride; - dequant( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._block_size, - _helper._SV, - _helper._value_group_size); + dequant(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._block_size, + _helper._SV); } } }); @@ -2090,18 +1721,14 @@ struct MHA { } }; -template +template struct AttentionExecutor : public PagedAttentionExecutor { - MHAHelper _helper; - MHA _kernel; + MHAHelper _helper; + MHA _kernel; PlainTensor _slot_mapping; AttentionExecutor() : _kernel(_helper) {} - explicit AttentionExecutor(size_t key_group_size, size_t value_group_size) - : _helper(MHAHelper(key_group_size, value_group_size)), - _kernel(_helper) {} - void init(const std::vector& inputs, const std::vector& outputs, PlainTensor& q, @@ -2142,22 +1769,8 @@ struct AttentionExecutor : public PagedAttentionExecutor { // The layout for per token per head for u8 kv cache: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The actual size needs to deduct scale and zeropoint. - const size_t key_sub_byte_multiplier = get_sub_byte_multiplier(k_cache.get_precision()); - const size_t value_sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); - const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplier; - // u4 needs scale + zp. s4 needs scale. - const size_t param_size = - one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); - const size_t value_params_size = param_size * value_sub_byte_multiplier; - size_t key_group_num = - _helper._key_group_size ? k_cache.size(3) / (_helper._key_group_size + key_params_size) : 1; - size_t value_group_num = - _helper._value_group_size ? v_cache.size(3) / (_helper._value_group_size + value_params_size) : 1; - auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 0 : key_params_size * key_group_num); - auto SV = v_cache.size(3) - (v_cache.get_precision().is_real() ? 0 : value_params_size * value_group_num); - // revise group_size if it's zero. - _helper._key_group_size = _helper._key_group_size ? _helper._key_group_size : S; - _helper._value_group_size = _helper._value_group_size ? _helper._value_group_size : SV; + auto S = k_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 : 0); + auto SV = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 : 0); auto block_size = k_cache.size(2); auto H = q.size(1) / S; auto h_each_group_len = 1; @@ -2172,10 +1785,9 @@ struct AttentionExecutor : public PagedAttentionExecutor { q = q.reshape({B_token, H, 1, S}); k = k.reshape({B_token, Hk, 1, S}); v = v.reshape({B_token, Hk, 1, SV}); - if (k_cache.m_dt == ov::element::Type_t::u8) { - k_cache.assert_dims({0, Hk, block_size, S + key_params_size * key_group_num}, true); - v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + value_params_size * value_group_num}); + k_cache.assert_dims({0, Hk, block_size, S + sizeof(float) * 2}, true); + v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + sizeof(float) * 2}); } else { k_cache.assert_dims({0, Hk, block_size, S}, true); v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV}); @@ -2225,13 +1837,7 @@ struct AttentionExecutor : public PagedAttentionExecutor { } if (k_cache.m_dt == ov::element::Type_t::u8) { - paged_attn_quantkv(k, - v, - k_cache, - v_cache, - _slot_mapping, - _helper._key_group_size, - _helper._value_group_size); + paged_attn_quantkv(k, v, k_cache, v_cache, _slot_mapping); } else { paged_attn_memcpy(k, v, k_cache, v_cache, _slot_mapping); } @@ -2281,79 +1887,40 @@ struct AttentionExecutor : public PagedAttentionExecutor { }; #endif -std::shared_ptr make_pa_executor(ov::element::Type data_type, - ov::element::Type key_cache_type, - ov::element::Type value_cache_type, - size_t key_group_size, - size_t value_group_size) { +std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type) { std::shared_ptr executor; #ifdef OPENVINO_ARCH_X86_64 if (data_type == ov::element::bf16) { # if defined(HAVE_AVX512F) - if (key_cache_type == ov::element::u8) { - if (value_cache_type == ov::element::u4) { - executor = - std::make_shared>(key_group_size, - value_group_size); - } else if (value_cache_type == ov::element::u8) { - executor = - std::make_shared>(key_group_size, - value_group_size); - } else { - OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", - value_cache_type.to_string(), - " is not support"); - } - + if (kvcache_type == ov::element::u8) { + executor = std::make_shared>(); } else { - OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(kvcache_type == ov::element::bf16, "expect kvcache type bf16, current: ", kvcache_type); + executor = std::make_shared>(); } # else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); # endif } else if (data_type == ov::element::f16) { # if defined(HAVE_AVX512F) - if (key_cache_type == ov::element::u8) { - if (value_cache_type == ov::element::u4) { - executor = std::make_shared>(key_group_size, - value_group_size); - } else if (value_cache_type == ov::element::u8) { - executor = std::make_shared>(key_group_size, - value_group_size); - } else { - OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", - value_cache_type.to_string(), - " is not support"); - } + if (kvcache_type == ov::element::u8) { + executor = std::make_shared>(); } else { - OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(kvcache_type == ov::element::f16, "expect kvcache type f16, current: ", kvcache_type); + executor = std::make_shared>(); } # else OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); # endif } else if (data_type == ov::element::f32) { - if (key_cache_type == ov::element::u8) { - if (value_cache_type == ov::element::u4) { - executor = std::make_shared>(key_group_size, - value_group_size); - } else if (value_cache_type == ov::element::u8) { - executor = std::make_shared>(key_group_size, - value_group_size); - } else { - OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", - value_cache_type.to_string(), - " is not support"); - } - } else if (key_cache_type == ov::element::f16) { - executor = std::make_shared>(key_group_size, - value_group_size); + if (kvcache_type == ov::element::u8) { + executor = std::make_shared>(); + } else if (kvcache_type == ov::element::f16) { + executor = std::make_shared>(); } else { - OPENVINO_ASSERT(key_cache_type == ov::element::f32, "expect kvcache type f32, current: ", key_cache_type); - executor = - std::make_shared>(key_group_size, value_group_size); + OPENVINO_ASSERT(kvcache_type == ov::element::f32, "expect kvcache type f32, current: ", kvcache_type); + executor = std::make_shared>(); } } else { OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp index 64e4eefc3b760d..d28125b3898460 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -17,11 +17,7 @@ namespace Extensions { namespace Cpu { namespace XARCH { -std::shared_ptr make_pa_executor(ov::element::Type data_type, - ov::element::Type key_cache_type, - ov::element::Type value_cache_type, - size_t key_group_size, - size_t value_group_size); +std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type); } // namespace XARCH } // namespace Cpu diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 54aa80e9dff7c0..b51b2b3d8029a9 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -84,14 +84,13 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { OPENVINO_ASSERT(orgInputNumber == 13, "The input number of PagedAttention should be 13."); // kvcache, float, [] - auto past_key_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); - auto past_value_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); + auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); config.inConfs[PagedAttentionExecutor::ID_KCACHE].setMemDesc( creatorsMap.at(LayoutType::ncsp) - ->createSharedDesc(past_key_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); + ->createSharedDesc(past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); config.inConfs[PagedAttentionExecutor::ID_VCACHE].setMemDesc( creatorsMap.at(LayoutType::ncsp) - ->createSharedDesc(past_value_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); + ->createSharedDesc(past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); // past_lens, int, [b_seq] config.inConfs[PagedAttentionExecutor::ID_PAST_LENS].setMemDesc( creatorsMap.at(LayoutType::ncsp) @@ -141,14 +140,8 @@ void PagedAttention::createPrimitive() { auto builder = [&](const PagedAttentionKey& key) -> std::shared_ptr { #ifdef OPENVINO_ARCH_X86_64 - // Since we are quantize only last dim it's safe to use the last dim of KV. - auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); - auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); - const auto& cpuConfig = context->getConfig(); - - size_t key_group_size = cpuConfig.keyCacheGroupSize; - size_t value_group_size = cpuConfig.valueCacheGroupSize; - return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size); + auto kvCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + return make_pa_executor(rtPrecision, kvCachePrecision); #else return nullptr; #endif @@ -209,20 +202,6 @@ void PagedAttention::execute(dnnl::stream strm) { bool PagedAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE); - auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE); - if (one_of(vCachePrecision, - ov::element::u4, - ov::element::u8, - ov::element::f32, - ov::element::f16, - ov::element::bf16)) { - if (!one_of(kCachePrecision, ov::element::u8, ov::element::f16, ov::element::f32, ov::element::bf16)) { - errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + - kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); - return false; - } - } int orgInput = static_cast(op->get_input_size()); if (op->get_type_name() == std::string("PagedAttentionExtension") && orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) { diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index a4af32ce07046a..71d39a1fce5ba8 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -436,14 +436,12 @@ struct MHAKernel { } T* v_ptr = is_xf16 ? &wv_scratch_b.at({b, h / h_each_group_len, 0}) : &present_value.at({b, h / h_each_group_len, 0, 0}); - wv_gemm_ptr->executeGemm(m_cntget_scratch_a_size()> 0 - ? &wv_scratch_a.at({tid, 0}) - : nullptr); + wv_gemm_ptr->executeGemm(m_cnt < m_block_size, + w_ptr, + v_ptr, + fp32_out_ptr, + wsp.data() + tid * wsp_size_per_thread, + wv_scratch_a ? &wv_scratch_a.at({tid, 0}) : nullptr); if (is_xf16) { if (has_out_transpose) { attn_memcpy2d_kernel(&fp32_out.at({b, m_start, h, 0}), @@ -1061,26 +1059,6 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptrgetConfig(); - const auto& keyCachePrecision = cpuConfig.keyCachePrecision; - const auto& valueCachePrecision = cpuConfig.valueCachePrecision; - const auto keyDims = getInputShapeAtPort(1).getDims(); - const auto valueDims = getInputShapeAtPort(2).getDims(); - const auto keyS = *(keyDims.end() - 1); - const auto valueS = *(valueDims.end() - 1); - OPENVINO_ASSERT(valueCachePrecision == keyCachePrecision, - "CPU: SDPA node only supports same key/value cache precision"); - OPENVINO_ASSERT(one_of(keyCachePrecision, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8), - "CPU: SDPA only supports key/value cache precision f32, f16, bf16, u8 but gets ", - keyCachePrecision); - m_key_quant_param.groupSize = (cpuConfig.keyCacheGroupSize == 0 || keyS % cpuConfig.keyCacheGroupSize != 0) - ? keyS - : cpuConfig.keyCacheGroupSize; - m_key_quant_param.precision = keyCachePrecision; - m_value_quant_param.groupSize = (cpuConfig.valueCacheGroupSize == 0 || valueS % cpuConfig.valueCacheGroupSize != 0) - ? valueS - : cpuConfig.valueCacheGroupSize; - m_key_quant_param.precision = valueCachePrecision; if (const auto node = std::dynamic_pointer_cast(op)) { m_config.config.is_causal = node->get_causal(); @@ -1855,16 +1833,12 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M ov::element::Type ScaledDotProductAttention::getKVCachePrecision() { ov::element::Type kvcache_precision; - // TODO: SDPA only supports same key/value cache precision. auto rtPrecision = getRuntimePrecision(); - auto keyCachePrecisionHint = context->getConfig().keyCachePrecision; - auto valueCachePrecisionHint = context->getConfig().valueCachePrecision; + auto kvCachePrecisionHint = context->getConfig().kvCachePrecision; bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && - rtPrecision != ov::element::bf16 && - (keyCachePrecisionHint == ov::element::f16 && valueCachePrecisionHint == ov::element::f16); + rtPrecision != ov::element::bf16 && kvCachePrecisionHint == ov::element::f16; kvcache_precision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision; - bool use_int8_kv_cache_precision = - (keyCachePrecisionHint == ov::element::u8 && valueCachePrecisionHint == ov::element::u8); + bool use_int8_kv_cache_precision = kvCachePrecisionHint == ov::element::u8; if (use_int8_kv_cache_precision) kvcache_precision = ov::element::u8; else diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 2917342314fafd..21b9056ba9517c 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -47,10 +47,7 @@ class ScaledDotProductAttention : public Node { real_order = {permute_axes[2], permute_axes[0], permute_axes[1], permute_axes[3]}; return real_order; } - struct SDPAQuantParam { - ov::element::Type precision = ov::element::undefined; - size_t groupSize = 0; - }; + ov::element::Type getKVCachePrecision(); private: @@ -89,8 +86,6 @@ class ScaledDotProductAttention : public Node { // (0, 1, 2, 3) for BHLS // (2, 0, 1, 3) for LBHS std::vector m_kvstate_layout = {2, 0, 1, 3}; - SDPAQuantParam m_key_quant_param; - SDPAQuantParam m_value_quant_param; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index 103c33b6c00be4..c0d8d655a753c8 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -377,14 +377,6 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& options) engConfig.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(engConfig.kvCachePrecision); - } else if (name == ov::key_cache_precision) { - return decltype(ov::key_cache_precision)::value_type(engConfig.keyCachePrecision); - } else if (name == ov::value_cache_precision) { - return decltype(ov::value_cache_precision)::value_type(engConfig.valueCachePrecision); - } else if (name == ov::key_cache_group_size) { - return decltype(ov::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize); - } else if (name == ov::value_cache_group_size) { - return decltype(ov::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize); } return get_ro_property(name, options); } @@ -428,10 +420,6 @@ ov::Any Plugin::get_ro_property(const std::string& name, const ov::AnyMap& optio RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), - RW_property(ov::key_cache_precision.name()), - RW_property(ov::value_cache_precision.name()), - RW_property(ov::key_cache_group_size.name()), - RW_property(ov::value_cache_group_size.name()), }; std::vector supportedProperties; diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp index 6655a2a5e7d48d..073faba7f8d96f 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp @@ -2,15 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "openvino/runtime/properties.hpp" - #include -#include "openvino/runtime/compiled_model.hpp" +#include "utils/properties_test.hpp" +#include "openvino/runtime/system_conf.hpp" #include "openvino/runtime/core.hpp" +#include "openvino/runtime/compiled_model.hpp" +#include "openvino/runtime/properties.hpp" #include "openvino/runtime/intel_cpu/properties.hpp" -#include "openvino/runtime/system_conf.hpp" -#include "utils/properties_test.hpp" namespace { @@ -41,10 +40,6 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSupportedPropertiesAreAvailable RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), - RO_property(ov::key_cache_precision.name()), - RO_property(ov::value_cache_precision.name()), - RO_property(ov::key_cache_group_size.name()), - RO_property(ov::value_cache_group_size.name()), }; ov::Core ie; @@ -88,7 +83,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSetROPropertiesThrow) { TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanThroughputHint) { ov::Core ie; - int32_t streams = 1; // throughput hint should apply higher number of streams + int32_t streams = 1; // throughput hint should apply higher number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -101,7 +96,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -114,7 +109,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::hint::performance_mode(ov::hint::PerformanceMode::LATENCY))); @@ -129,7 +124,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPrior TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanThroughputHint) { ov::Core ie; - int32_t streams = 1; // throughput hint should apply higher number of streams + int32_t streams = 1; // throughput hint should apply higher number of streams int32_t value = 0; ov::AnyMap config; @@ -187,36 +182,6 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckKVCachePrecision) { ASSERT_EQ(kv_cache_precision_value, ov::element::f32); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) { - ov::Core core; - - core.set_property(deviceName, ov::key_cache_precision(ov::element::f16)); - core.set_property(deviceName, ov::value_cache_precision(ov::element::u4)); - ov::CompiledModel compiledModel = core.compile_model(model, deviceName); - - auto key_cache_precision_value = ov::element::undefined; - auto value_cache_precision_value = ov::element::undefined; - OV_ASSERT_NO_THROW(key_cache_precision_value = compiledModel.get_property(ov::key_cache_precision)); - OV_ASSERT_NO_THROW(value_cache_precision_value = compiledModel.get_property(ov::value_cache_precision)); - ASSERT_EQ(key_cache_precision_value, ov::element::f16); - ASSERT_EQ(value_cache_precision_value, ov::element::u4); -} - -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCacheGroupSize) { - ov::Core core; - - core.set_property(deviceName, ov::key_cache_group_size(32)); - core.set_property(deviceName, ov::value_cache_group_size(16)); - ov::CompiledModel compiledModel = core.compile_model(model, deviceName); - - auto key_cache_group_size_value = 0; - auto value_cache_group_size_value = 0; - OV_ASSERT_NO_THROW(key_cache_group_size_value = compiledModel.get_property(ov::key_cache_group_size)); - OV_ASSERT_NO_THROW(value_cache_group_size_value = compiledModel.get_property(ov::value_cache_group_size)); - ASSERT_EQ(key_cache_group_size_value, 32); - ASSERT_EQ(value_cache_group_size_value, 16); -} - TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeDynamicQuantizationGroupSize) { ov::Core core; @@ -260,8 +225,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckExecutionModeIsAvailableIn ASSERT_FALSE(model_exec_mode_it->is_mutable()); } -TEST_F(OVClassConfigTestCPU, - smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCoreInferencePrecision) { +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCoreInferencePrecision) { ov::Core ie; auto inference_precision_value = ov::element::undefined; @@ -275,8 +239,7 @@ TEST_F(OVClassConfigTestCPU, ASSERT_EQ(inference_precision_value, bf16_if_can_be_emulated); } -TEST_F(OVClassConfigTestCPU, - smoke_CpuExecNetworkCheckCoreInferencePrecisionHasHigherPriorityThanModelPerformanceExecutionMode) { +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreInferencePrecisionHasHigherPriorityThanModelPerformanceExecutionMode) { ov::Core ie; auto execution_mode_value = ov::hint::ExecutionMode::ACCURACY; auto inference_precision_value = ov::element::undefined; @@ -294,8 +257,7 @@ TEST_F(OVClassConfigTestCPU, ASSERT_EQ(inference_precision_value, ov::element::f32); } -TEST_F(OVClassConfigTestCPU, - smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCorePerformanceExecutionMode) { +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCorePerformanceExecutionMode) { ov::Core ie; auto execution_mode_value = ov::hint::ExecutionMode::PERFORMANCE; auto inference_precision_value = ov::element::undefined; @@ -326,13 +288,14 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckLogLevel) { OV_ASSERT_NO_THROW(value = compiledModel.get_property(ov::log::level)); ASSERT_EQ(value.as(), ov::log::Level::NO); } - // check set and get - const std::vector logLevels = {ov::log::Level::ERR, - ov::log::Level::NO, - ov::log::Level::WARNING, - ov::log::Level::INFO, - ov::log::Level::DEBUG, - ov::log::Level::TRACE}; + //check set and get + const std::vector logLevels = { + ov::log::Level::ERR, + ov::log::Level::NO, + ov::log::Level::WARNING, + ov::log::Level::INFO, + ov::log::Level::DEBUG, + ov::log::Level::TRACE}; for (unsigned int i = 0; i < logLevels.size(); i++) { ov::Any value; @@ -367,109 +330,50 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptions) { ov::Core ie; ov::Any type; ov::Any size; - ov::Any keySize; - ov::Any valueSize; - ov::Any keyCacheType; - ov::Any valueCacheType; ov::CompiledModel compiledModel; model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name()); - model->set_rt_info("32", "runtime_options", ov::key_cache_group_size.name()); - model->set_rt_info("16", "runtime_options", ov::value_cache_group_size.name()); - model->set_rt_info("u8", "runtime_options", ov::key_cache_precision.name()); - model->set_rt_info("u8", "runtime_options", ov::value_cache_precision.name()); OV_ASSERT_NO_THROW(compiledModel = ie.compile_model(model, deviceName)); OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision)); OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size)); - OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size)); - OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size)); - OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision)); - OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision)); ASSERT_EQ(type.as(), ov::element::f16); ASSERT_EQ(size.as(), 0); - ASSERT_EQ(keySize.as(), 32); - ASSERT_EQ(valueSize.as(), 16); - ASSERT_EQ(keyCacheType.as(), ov::element::u8); - ASSERT_EQ(valueCacheType.as(), ov::element::u8); } TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptionsWithCompileConfig) { ov::Core ie; ov::Any type; ov::Any size; - ov::Any keySize; - ov::Any valueSize; - ov::Any keyCacheType; - ov::Any valueCacheType; ov::CompiledModel compiledModel; model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name()); - model->set_rt_info("0", "runtime_options", ov::key_cache_group_size.name()); - model->set_rt_info("0", "runtime_options", ov::value_cache_group_size.name()); - model->set_rt_info("f32", "runtime_options", ov::key_cache_precision.name()); - model->set_rt_info("f32", "runtime_options", ov::value_cache_precision.name()); ov::AnyMap config; config[ov::hint::kv_cache_precision.name()] = "u8"; config[ov::hint::dynamic_quantization_group_size.name()] = "16"; - // propperty has higher priority than rt_info - config[ov::key_cache_group_size.name()] = "32"; - config[ov::value_cache_group_size.name()] = "16"; - // key/value cache prec has higher priority than kvCachePrec - config[ov::key_cache_precision.name()] = "f16"; - config[ov::value_cache_precision.name()] = "bf16"; OV_ASSERT_NO_THROW(compiledModel = ie.compile_model(model, deviceName, config)); OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision)); OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size)); - OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size)); - OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size)); - OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision)); - OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision)); ASSERT_EQ(type.as(), ov::element::u8); ASSERT_EQ(size.as(), 16); - ASSERT_EQ(keySize.as(), 32); - ASSERT_EQ(valueSize.as(), 16); - ASSERT_EQ(keyCacheType.as(), ov::element::f16); - ASSERT_EQ(valueCacheType.as(), ov::element::bf16); } TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptionsWithCoreProperties) { ov::Core core; ov::Any type; ov::Any size; - ov::Any keySize; - ov::Any valueSize; - ov::Any keyCacheType; - ov::Any valueCacheType; + core.set_property(deviceName, ov::hint::kv_cache_precision(ov::element::f32)); core.set_property(deviceName, ov::hint::dynamic_quantization_group_size(16)); - core.set_property(deviceName, ov::key_cache_group_size(8)); - core.set_property(deviceName, ov::value_cache_group_size(8)); - core.set_property(deviceName, ov::key_cache_precision(ov::element::f16)); - core.set_property(deviceName, ov::value_cache_precision(ov::element::bf16)); ov::CompiledModel compiledModel; model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name()); - model->set_rt_info("32", "runtime_options", ov::key_cache_group_size.name()); - model->set_rt_info("16", "runtime_options", ov::value_cache_group_size.name()); - // User's setting has higher priority than rt_info - model->set_rt_info("f32", "runtime_options", ov::key_cache_precision.name()); - model->set_rt_info("f32", "runtime_options", ov::value_cache_precision.name()); OV_ASSERT_NO_THROW(compiledModel = core.compile_model(model, deviceName)); OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision)); OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size)); - OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size)); - OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size)); - OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision)); - OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision)); - ASSERT_EQ(type.as(), ov::element::f32); ASSERT_EQ(size.as(), 16); - ASSERT_EQ(keySize.as(), 8); - ASSERT_EQ(valueSize.as(), 8); - ASSERT_EQ(keyCacheType.as(), ov::element::f16); - ASSERT_EQ(valueCacheType.as(), ov::element::bf16); } -} // namespace +} // namespace diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp index 5adf6cbb125185..f72df3f58b69e5 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp @@ -55,10 +55,6 @@ TEST_F(OVClassConfigTestCPU, smoke_PluginAllSupportedPropertiesAreAvailable) { RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), - RW_property(ov::key_cache_precision.name()), - RW_property(ov::value_cache_precision.name()), - RW_property(ov::key_cache_group_size.name()), - RW_property(ov::value_cache_group_size.name()), }; ov::Core ie;