From 04a7ecfadf31484aa4485d6f15cf37f26bb8df25 Mon Sep 17 00:00:00 2001
From: Luo Cheng
Date: Wed, 15 May 2024 16:46:14 +0800
Subject: [PATCH] [CPU] PagedAttention supports dynamic-split fuse (#24107)

### Details:
- *Merge first-token and second-token inference into one parallel loop (a simplified scheduling sketch follows the first hunks below)*
- *~~Additional optimization: pre-transpose k-cache, pre-pack v-cache if needed~~*
- *Additional optimization for the first token: save the q * k' upper-triangle matrix computation and the (q * k') * v lower-triangle matrix computation*
- *The C++ pipeline can enable it: https://github.com/ilya-lavrenov/openvino.genai/pull/9*
- *TODO (in another PR):*
  - alibi support
  - performance tuning
  - test cases

### Tickets:
- *[138673](https://jira.devtools.intel.com/browse/CVS-138673)*
---
 .../cross_compile/cross_compiled_func.cmake   |    3 +-
 src/plugins/intel_cpu/CMakeLists.txt          |    7 +
 src/plugins/intel_cpu/src/cpu_types.cpp       |    3 +-
 src/plugins/intel_cpu/src/cpu_types.h         |    1 +
 src/plugins/intel_cpu/src/graph.cpp           |    3 +-
 .../nodes/kernels/scaled_attn/attn_quant.cpp  |   33 +-
 .../kernels/scaled_attn/attn_quant_kernel.hpp |   56 +
 .../nodes/kernels/scaled_attn/executor_pa.cpp | 1619 +++++++++++++++++
 .../nodes/kernels/scaled_attn/executor_pa.hpp |   24 +
 .../scaled_attn/executor_pa_common.cpp        |  113 ++
 .../scaled_attn/executor_pa_common.hpp        |  107 ++
 .../kernels/scaled_attn/mha_single_token.cpp  |  880 ++-------
 .../kernels/scaled_attn/mha_single_token.hpp  |    2 -
 .../kernels/scaled_attn/softmax_kernel.hpp    |    4 +-
 .../kernels/scaled_attn/transpose_kernel.hpp  |  254 +++
 .../src/nodes/kernels/x64/brgemm_kernel.cpp   |   41 +-
 .../src/nodes/kernels/x64/brgemm_kernel.hpp   |    4 +-
 .../intel_cpu/src/nodes/paged_attn.cpp        |  216 +++
 src/plugins/intel_cpu/src/nodes/paged_attn.h  |   51 +
 .../intel_cpu/src/nodes/scaled_attn.cpp       |  660 ++-----
 src/plugins/intel_cpu/src/nodes/scaled_attn.h |   18 -
 src/plugins/intel_cpu/src/nodes_factory.cpp   |    2 +
 .../src/shape_inference/custom/paged_attn.cpp |   38 +
 .../src/shape_inference/custom/paged_attn.hpp |   24 +
 24 files changed, 2815 insertions(+), 1348 deletions(-)
 create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/paged_attn.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/paged_attn.h
 create mode 100644 src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.cpp
 create mode 100644 src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.hpp

diff --git a/cmake/developer_package/cross_compile/cross_compiled_func.cmake b/cmake/developer_package/cross_compile/cross_compiled_func.cmake
index a83bfcffc53348..5f7b72ff47edff 100644
--- a/cmake/developer_package/cross_compile/cross_compiled_func.cmake
+++ b/cmake/developer_package/cross_compile/cross_compiled_func.cmake
@@ -81,7 +81,8 @@ function(cross_compiled_file TARGET)
             ## that is arch ID
             set(_arch ${_it})
             if(_arch MATCHES ${_CURRENT_ARCH_FILTER})
-                list(APPEND _CUR_ARCH_SET ${_arch})
+                # make the non/less-optimized version come first
+                list(INSERT _CUR_ARCH_SET 0 ${_arch})
                 list(APPEND _FULL_ARCH_SET ${_arch})
             endif()
         else()
diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt
index 4c24c486671b53..57d5db7ee26b27 100644
--- a/src/plugins/intel_cpu/CMakeLists.txt
+++ b/src/plugins/intel_cpu/CMakeLists.txt
@@ -209,6 +209,13 @@ cross_compiled_file(${TARGET_NAME}
         NAME mha_single_token
         NAMESPACE ov::Extensions::Cpu::XARCH
 )
+cross_compiled_file(${TARGET_NAME}
+        ARCH AVX512F AVX2 ANY
+        src/nodes/kernels/scaled_attn/executor_pa.cpp
+        API src/nodes/kernels/scaled_attn/executor_pa.hpp
+        NAME make_pa_executor
+        NAMESPACE ov::Extensions::Cpu::XARCH
+)
 cross_compiled_file(${TARGET_NAME}
         ARCH AVX512F AVX2 ANY
         src/nodes/kernels/scaled_attn/attn_memcpy.cpp
diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp
index 73ad7b36b4a459..b66170e24a8558 100644
--- a/src/plugins/intel_cpu/src/cpu_types.cpp
+++ b/src/plugins/intel_cpu/src/cpu_types.cpp
@@ -237,7 +237,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         {"Ngram", Type::Ngram},
         {"ScaledDotProductAttention", Type::ScaledDotProductAttention},
         {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
-        {"PagedAttentionExtension", Type::ScaledDotProductAttention},
+        {"PagedAttentionExtension", Type::PagedAttention},
         {"RoPE", Type::RoPE},
         {"GatherCompressed", Type::Gather},
         {"CausalMaskPreprocess", Type::CausalMaskPreprocess},
@@ -358,6 +358,7 @@ std::string NameFromType(const Type type) {
         CASE(Unique);
         CASE(Ngram);
         CASE(ScaledDotProductAttention);
+        CASE(PagedAttention);
         CASE(RoPE);
         CASE(CausalMaskPreprocess);
         CASE(Unknown);
diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h
index 45c3617f9b8e2d..fdf7bd4c379c19 100644
--- a/src/plugins/intel_cpu/src/cpu_types.h
+++ b/src/plugins/intel_cpu/src/cpu_types.h
@@ -118,6 +118,7 @@ enum class Type {
     Unique,
     Ngram,
     ScaledDotProductAttention,
+    PagedAttention,
     RoPE,
     CausalMaskPreprocess,
 };
diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp
index bfd6f949563a14..29c8f0882235f8 100644
--- a/src/plugins/intel_cpu/src/graph.cpp
+++ b/src/plugins/intel_cpu/src/graph.cpp
@@ -1811,8 +1811,7 @@ void Graph::EnforceInferencePrecision() {
                 return true;

             // kvcache of PagedAttention should be written directly
-            if (node->getType() == Type::ScaledDotProductAttention && node->getOriginalInputsNumber() == 13 &&
-                (inPort == 3 || inPort == 4))
+            if (node->getType() == Type::PagedAttention && (inPort == 3 || inPort == 4))
                 return true;
             const auto &parent = node->getParentEdgeAt(inPort)->getParent();
             /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
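The first Details bullet above merges first-token (prompt) and second-token (generation) inference into one parallel loop. A minimal sketch of that scheduling idea, under simplified bookkeeping and with hypothetical names (`WorkItem`, `build_work_items`), might look like this; it is not the plugin code itself.

```cpp
// Hypothetical sketch of dynamic-split-fuse scheduling (not the plugin code).
// Prompt sequences (q_len > 1) are split into block-sized query chunks and
// generation sequences (q_len == 1) contribute one item each, so a single
// parallel loop can walk the flattened item list for both token types.
#include <cstdint>
#include <vector>

struct WorkItem {
    int32_t batch_in_query;  // index of the first token of this sequence in the flattened query
    int32_t q_len;           // 1 for second-token (generation), >1 for first-token (prompt)
    int32_t q_block_id;      // query block index inside this sequence
};

static std::vector<WorkItem> build_work_items(const std::vector<int32_t>& subsequence_lens,
                                              int32_t block_size) {
    std::vector<WorkItem> items;
    int32_t start = 0;
    for (int32_t q_len : subsequence_lens) {
        const int32_t q_blocks = (q_len + block_size - 1) / block_size;  // ceil(q_len / block_size)
        for (int32_t blk = 0; blk < q_blocks; blk++)
            items.push_back({start, q_len, blk});
        start += q_len;
    }
    return items;  // later: one parallel loop over items x kv-heads handles both token types
}
```

In the patch itself this role is played by `MHAMixed::WorkItems::reset()` in `executor_pa.cpp` below, which additionally records separate reorder work items so K/V cache blocks can be transposed/packed before the attention loop.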
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 22be15c09a837d..c4c92da96193b1 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -17,6 +17,7 @@ #include "openvino/core/parallel.hpp" #include "common.hpp" #include "attn_quant.hpp" +#include "attn_quant_kernel.hpp" namespace ov { namespace Extensions { @@ -259,37 +260,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float } void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) { - size_t i = 0; - // loadu_si128/epi64 does not support const qualifier - uint8_t* src_nc = const_cast(src); -#if defined(HAVE_AVX512F) - auto v_zp = _mm512_set1_ps(zp); - auto v_scale = _mm512_set1_ps(scale); - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto v0_128 = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i)); - auto v0_512 = _mm512_cvtepu8_epi32(v0_128); - auto v0_value = _mm512_cvtepi32_ps(v0_512); - v0_value = _mm512_sub_ps(v0_value, v_zp); - auto v0_out = _mm512_mul_ps(v0_value, v_scale); - mm512_uni_storeu_ps(dst + i, v0_out); - } -#elif defined(HAVE_AVX2) - auto v_zp = _mm256_set1_ps(zp); - auto v_scale = _mm256_set1_ps(scale); - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto v0_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i)); - auto v0_256 = _mm256_cvtepu8_epi32(v0_128); - auto v0_value = _mm256_cvtepi32_ps(v0_256); - v0_value = _mm256_sub_ps(v0_value, v_zp); - auto v0_out = _mm256_mul_ps(v0_value, v_scale); - mm256_uni_storeu_ps(dst + i, v0_out); - } -#endif - for (; i < n; ++i) { - float tmp = src_nc[i]; - tmp = (tmp - zp) * scale; - dst[i] = tmp; - } + attn_dequant_u8_kernel(src, dst, n, scale, zp); } } // namespace XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp new file mode 100644 index 00000000000000..4e013a004d29f9 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include "openvino/core/type/element_type.hpp" +#include "utils/plain_tensor.hpp" + +namespace ov { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +template +void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { + size_t i = 0; + // loadu_si128/epi64 does not support const qualifier + uint8_t* src_nc = const_cast(src); +#if defined(HAVE_AVX512F) + auto v_zp = _mm512_set1_ps(zp); + auto v_scale = _mm512_set1_ps(scale); + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto v0_128 = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i)); + auto v0_512 = _mm512_cvtepu8_epi32(v0_128); + auto v0_value = _mm512_cvtepi32_ps(v0_512); + v0_value = _mm512_sub_ps(v0_value, v_zp); + auto v0_out = _mm512_mul_ps(v0_value, v_scale); + mm512_uni_storeu_ps(dst + i, v0_out); + } +#elif defined(HAVE_AVX2) + auto v_zp = _mm256_set1_ps(zp); + auto v_scale = _mm256_set1_ps(scale); + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto v0_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i)); + auto v0_256 = _mm256_cvtepu8_epi32(v0_128); + auto v0_value = 
_mm256_cvtepi32_ps(v0_256); + v0_value = _mm256_sub_ps(v0_value, v_zp); + auto v0_out = _mm256_mul_ps(v0_value, v_scale); + mm256_uni_storeu_ps(dst + i, v0_out); + } +#endif + for (; i < n; ++i) { + float tmp = src_nc[i]; + tmp = (tmp - zp) * scale; + dst[i] = tmp; + } +} + +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp new file mode 100644 index 00000000000000..3dce6bc8ec4be0 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -0,0 +1,1619 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include +#include +#include +#include +#include + +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/parallel.hpp" +#include "executor_pa.hpp" +#include "executor_pa_common.hpp" +#include "common.hpp" +#include "attn_quant_kernel.hpp" +#include "softmax_kernel.hpp" +#include "transpose_kernel.hpp" +#include "utils/plain_tensor.hpp" +#include "attn_memcpy.hpp" +#include "nodes/kernels/x64/brgemm_kernel.hpp" + +namespace ov { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +using namespace ov; +using namespace ov::intel_cpu; + +// currently depends on brgemm which only support x64 +#ifdef OPENVINO_ARCH_X86_64 + +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) + +#define prefetch_bytes(bytes, sel, advance, src) { \ + auto *p = reinterpret_cast(src); \ + for (size_t i = 0; i < bytes; i += 64) \ + _mm_prefetch(p + i + advance, sel); \ +} + +#else + +#define prefetch_bytes(bytes, sel, advance, src) + +#endif + +template +void cvt_copy(TA* dst, TB* src, size_t n) { + size_t i = 0; +#if defined(HAVE_AVX512F) + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto vb = mm512_uni_loadu_ps(src + i); + mm512_uni_storeu_ps(dst + i, vb); + } +#elif defined(HAVE_AVX2) + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto vb = mm256_uni_loadu_ps(src + i); + mm256_uni_storeu_ps(dst + i, vb); + } +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template +static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size) { +#if defined(HAVE_AVX512F) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto attn_w_vec0 = _mm512_set1_ps(weight[0]); + auto attn_w_vec1 = _mm512_set1_ps(weight[1]); + auto attn_w_vec2 = _mm512_set1_ps(weight[2]); + auto attn_w_vec3 = _mm512_set1_ps(weight[3]); + size_t i = 0; + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + v_out = _mm512_fmadd_ps(attn_w_vec0, mm512_uni_loadu_ps(v + i), v_out); + v_out = _mm512_fmadd_ps(attn_w_vec1, mm512_uni_loadu_ps(v + i + S), v_out); + v_out = _mm512_fmadd_ps(attn_w_vec2, mm512_uni_loadu_ps(v + i + S * 2), v_out); + v_out = _mm512_fmadd_ps(attn_w_vec3, mm512_uni_loadu_ps(v + i + S * 3), v_out); + + _mm512_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * v[i]; + out[i] += weight[1] * v[i + S]; + out[i] += weight[2] * v[i + S * 2]; + out[i] += weight[3] * v[i + S * 3]; + } + v += 4 * S; + weight += 4; + } + if (j + 2 <= block_size) { + auto attn_w_vec0 = _mm512_set1_ps(weight[0]); + auto attn_w_vec1 = _mm512_set1_ps(weight[1]); + size_t i = 0; + for (; i + vec_len_f32_avx512 <= S; i += 
vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + v_out = _mm512_fmadd_ps(attn_w_vec0, mm512_uni_loadu_ps(v + i), v_out); + v_out = _mm512_fmadd_ps(attn_w_vec1, mm512_uni_loadu_ps(v + i + S), v_out); + + _mm512_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * v[i]; + out[i] += weight[1] * v[i + S]; + } + v += 2 * S; + weight += 2; + j += 2; + } + if (j < block_size) { + auto attn_w_vec0 = _mm512_set1_ps(weight[0]); + size_t i = 0; + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + v_out = _mm512_fmadd_ps(attn_w_vec0, mm512_uni_loadu_ps(v + i), v_out); + + _mm512_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * v[i]; + } + } + return; +#elif defined(HAVE_AVX2) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto attn_w_vec0 = _mm256_set1_ps(weight[0]); + auto attn_w_vec1 = _mm256_set1_ps(weight[1]); + auto attn_w_vec2 = _mm256_set1_ps(weight[2]); + auto attn_w_vec3 = _mm256_set1_ps(weight[3]); + size_t i = 0; + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + i); + v_out = _mm256_fmadd_ps(attn_w_vec0, mm256_uni_loadu_ps(v + i), v_out); + v_out = _mm256_fmadd_ps(attn_w_vec1, mm256_uni_loadu_ps(v + i + S), v_out); + v_out = _mm256_fmadd_ps(attn_w_vec2, mm256_uni_loadu_ps(v + i + S * 2), v_out); + v_out = _mm256_fmadd_ps(attn_w_vec3, mm256_uni_loadu_ps(v + i + S * 3), v_out); + + mm256_uni_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * v[i]; + out[i] += weight[1] * v[i + S]; + out[i] += weight[2] * v[i + S * 2]; + out[i] += weight[3] * v[i + S * 3]; + } + v += 4 * S; + weight += 4; + } + if (j + 2 <= block_size) { + auto attn_w_vec0 = _mm256_set1_ps(weight[0]); + auto attn_w_vec1 = _mm256_set1_ps(weight[1]); + size_t i = 0; + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + i); + v_out = _mm256_fmadd_ps(attn_w_vec0, mm256_uni_loadu_ps(v + i), v_out); + v_out = _mm256_fmadd_ps(attn_w_vec1, mm256_uni_loadu_ps(v + i + S), v_out); + + mm256_uni_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * v[i]; + out[i] += weight[1] * v[i + S]; + } + v += 2 * S; + weight += 2; + j += 2; + } + if (j < block_size) { + auto attn_w_vec0 = _mm256_set1_ps(weight[0]); + size_t i = 0; + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + i); + v_out = _mm256_fmadd_ps(attn_w_vec0, mm256_uni_loadu_ps(v + i), v_out); + + mm256_uni_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * v[i]; + } + } + return; +#endif + for (size_t j = 0; j < block_size; j++) { + for (size_t i = 0; i < S; i++) { + out[i] += weight[j] * v[i]; + } + v += S; + } +} + +static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) +#if defined(HAVE_AVX512F) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto v_f0 = reinterpret_cast(v); + auto v_f1 = reinterpret_cast(v + S + 8); + auto v_f2 = reinterpret_cast(v + 2 * (S + 8)); + auto v_f3 = reinterpret_cast(v + 3 * (S + 8)); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); + 
auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); + auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + auto zp1 = _mm512_set1_ps(v_f1[1]); + auto zp2 = _mm512_set1_ps(v_f2[1]); + auto zp3 = _mm512_set1_ps(v_f3[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), zp0); + auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + S + 8)))), zp1); + auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 2 * (S + 8))))), zp2); + auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 3 * (S + 8))))), zp3); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); + + _mm512_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + out[i] += weight[1] * (v[i + S + 8] - v_f1[1]) * v_f1[0]; + out[i] += weight[2] * (v[i + 2 * (S + 8)] - v_f2[1]) * v_f2[0]; + out[i] += weight[3] * (v[i + 3 * (S + 8)] - v_f3[1]) * v_f3[0]; + } + v += 4 * (S + 8) - 8; + weight += 4; + } + for (; j < block_size; j++) { + auto v_f0 = reinterpret_cast(v); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + i); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), zp0); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + + _mm512_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + } + v += S; + weight++; + } + return; +#elif defined(HAVE_AVX2) + size_t j = 0; + for (; j < block_size; j++) { + auto v_f0 = reinterpret_cast(v); + auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm256_set1_ps(v_f0[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + i); + auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i)))), zp0); + v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); + + mm256_uni_storeu_ps(out + i, v_out); + } + for (; i < S; i++) { + out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + } + v += S; + weight++; + } + return; +#endif + for (size_t j = 0; j < block_size; j++) { + auto v0 = reinterpret_cast(v); + v += 8; + for (size_t i = 0; i < S; i++) { + out[i] += weight[j] * (v[i] - v0[1]) * v0[0]; + } + v += S; + } +} + +template +static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size) { +#if defined(HAVE_AVX512F) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto vsum0 = _mm512_setzero_ps(); + auto vsum1 = _mm512_setzero_ps(); + auto vsum2 = _mm512_setzero_ps(); + auto vsum3 = _mm512_setzero_ps(); + size_t i = 0; + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + vsum0 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i), vsum0); + vsum1 = _mm512_fmadd_ps(va, 
mm512_uni_loadu_ps(b + i + n), vsum1); + vsum2 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i + 2 * n), vsum2); + vsum3 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i + 3 * n), vsum3); + } + float sum0 = _mm512_reduce_add_ps(vsum0); + float sum1 = _mm512_reduce_add_ps(vsum1); + float sum2 = _mm512_reduce_add_ps(vsum2); + float sum3 = _mm512_reduce_add_ps(vsum3); + for (; i < n; i++) { + sum0 += a[i] * b[i]; + sum1 += a[i] * b[i + n]; + sum2 += a[i] * b[i + 2 * n]; + sum3 += a[i] * b[i + 3 * n]; + } + c[0] = sum0; + c[1] = sum1; + c[2] = sum2; + c[3] = sum3; + c += 4; + b += 4 * n; + } + for (; j < block_size; j++) { + auto vsum = _mm512_setzero_ps(); + size_t i = 0; + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + vsum = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i), vsum); + } + float sum = _mm512_reduce_add_ps(vsum); + for (; i < n; i++) { + sum += a[i] * b[i]; + } + b += n; + *c++ = sum; + } + return; +#elif defined(HAVE_AVX2) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto vsum0 = _mm256_set1_ps(0.0f); + auto vsum1 = _mm256_set1_ps(0.0f); + auto vsum2 = _mm256_set1_ps(0.0f); + auto vsum3 = _mm256_set1_ps(0.0f); + size_t i = 0; + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + vsum0 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i), vsum0); + vsum1 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i + n), vsum1); + vsum2 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i + 2 * n), vsum2); + vsum3 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i + 3 * n), vsum3); + } + hsum(vsum0); + hsum(vsum1); + hsum(vsum2); + hsum(vsum3); + float sum0 = _mm256_cvtss_f32(vsum0); + float sum1 = _mm256_cvtss_f32(vsum1); + float sum2 = _mm256_cvtss_f32(vsum2); + float sum3 = _mm256_cvtss_f32(vsum3); + for (; i < n; i++) { + sum0 += a[i] * b[i]; + sum1 += a[i] * b[i + n]; + sum2 += a[i] * b[i + 2 * n]; + sum3 += a[i] * b[i + 3 * n]; + } + c[0] = sum0; + c[1] = sum1; + c[2] = sum2; + c[3] = sum3; + c += 4; + b += 4 * n; + } + for (; j < block_size; j++) { + auto vsum = _mm256_set1_ps(0.0f); + size_t i = 0; + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + vsum = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i), vsum); + } + hsum(vsum); + float sum = _mm256_cvtss_f32(vsum); + for (; i < n; i++) { + sum += a[i] * b[i]; + } + b += n; + *c++ = sum; + } + return; +#endif + for (size_t j = 0; j < block_size; j++) { + float sum = 0; + for (size_t i = 0; i < n; i++) { + sum += a[i] * b[i]; + } + b += n; + *c++ = sum; + } +} + +template +static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) +#if defined(HAVE_AVX512F) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto vsum0 = _mm512_setzero_ps(); + auto vsum1 = _mm512_setzero_ps(); + auto vsum2 = _mm512_setzero_ps(); + auto vsum3 = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto b1 = reinterpret_cast(b + n + 8); + auto b2 = reinterpret_cast(b + (n + 8) * 2); + auto b3 = reinterpret_cast(b + (n + 8) * 3); + auto v_zp0 = _mm512_set1_ps(b0[1]); + auto v_zp1 = _mm512_set1_ps(b1[1]); + auto v_zp2 = _mm512_set1_ps(b2[1]); + auto v_zp3 = _mm512_set1_ps(b3[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx512 <= 
n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), v_zp0); + auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + n + 8)))), v_zp1); + auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), v_zp2); + auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), v_zp3); + + vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); + } + float sum0 = _mm512_reduce_add_ps(vsum0); + float sum1 = _mm512_reduce_add_ps(vsum1); + float sum2 = _mm512_reduce_add_ps(vsum2); + float sum3 = _mm512_reduce_add_ps(vsum3); + for (; i < n; i++) { + sum0 += a[i] * (b[i] - b0[1]); + sum1 += a[i] * (b[i + n + 8] - b1[1]); + sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); + sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); + } + c[0] = sum0 * b0[0]; + c[1] = sum1 * b1[0]; + c[2] = sum2 * b2[0]; + c[3] = sum3 * b3[0]; + c += 4; + b += 4 * (n + 8) - 8; + } + for (; j < block_size; j++) { + auto vsum = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto v_zp = _mm512_set1_ps(b0[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + i); + auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), v_zp); + vsum = _mm512_fmadd_ps(va, vb, vsum); + } + float sum = _mm512_reduce_add_ps(vsum); + for (; i < n; i++) { + sum += a[i] * (b[i] - b0[1]); + } + b += n; + *c++ = sum * b0[0]; + } + return; +#elif defined(HAVE_AVX2) + size_t j = 0; + for (; j + 4 <= block_size; j += 4) { + auto vsum0 = _mm256_setzero_ps(); + auto vsum1 = _mm256_setzero_ps(); + auto vsum2 = _mm256_setzero_ps(); + auto vsum3 = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto b1 = reinterpret_cast(b + n + 8); + auto b2 = reinterpret_cast(b + (n + 8) * 2); + auto b3 = reinterpret_cast(b + (n + 8) * 3); + auto v_zp0 = _mm256_set1_ps(b0[1]); + auto v_zp1 = _mm256_set1_ps(b1[1]); + auto v_zp2 = _mm256_set1_ps(b2[1]); + auto v_zp3 = _mm256_set1_ps(b3[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), v_zp0); + auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + n + 8)))), v_zp1); + auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), v_zp2); + auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), v_zp3); + + vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); + } + hsum(vsum0); + hsum(vsum1); + hsum(vsum2); + hsum(vsum3); + float sum0 = _mm256_cvtss_f32(vsum0); + float sum1 = _mm256_cvtss_f32(vsum1); + float sum2 = _mm256_cvtss_f32(vsum2); + float sum3 = _mm256_cvtss_f32(vsum3); + for (; i < n; i++) { + sum0 += a[i] * (b[i] - 
b0[1]); + sum1 += a[i] * (b[i + n + 8] - b1[1]); + sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); + sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); + } + c[0] = sum0 * b0[0]; + c[1] = sum1 * b1[0]; + c[2] = sum2 * b2[0]; + c[3] = sum3 * b3[0]; + c += 4; + b += 4 * (n + 8) - 8; + } + for (; j < block_size; j++) { + auto vsum = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b); + auto v_zp = _mm256_set1_ps(b0[1]); + size_t i = 0; + b += 8; + for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + i); + auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), v_zp); + vsum = _mm256_fmadd_ps(va, vb, vsum); + } + hsum(vsum); + float sum = _mm256_cvtss_f32(vsum); + for (; i < n; i++) { + sum += a[i] * (b[i] - b0[1]); + } + b += n; + *c++ = sum * b0[0]; + } + return; +#endif + for (size_t j = 0; j < block_size; j++) { + float sum = 0; + auto b0 = reinterpret_cast(b); + b += 8; + for (size_t i = 0; i < n; i++) { + sum += a[i] * (b[i] - b0[1]); + } + b += n; + *c++ = sum * b0[0]; + } +} + +template +static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) { + size_t i = 0; +#if defined(HAVE_AVX512F) + for (; i + vec_len_f32_avx512 <= S; i+= vec_len_f32_avx512) { + auto* src = temp + i; + auto result_vec_fp32 = _mm512_setzero_ps(); + for (size_t m = 0; m < M; m++) { + auto o_vec_fp32 = _mm512_loadu_ps(src); + result_vec_fp32 = _mm512_add_ps(result_vec_fp32, o_vec_fp32); + src += temp_stride; + } + // save to bf16 + mm512_uni_storeu_ps(dst + i, result_vec_fp32); + } +#elif defined(HAVE_AVX2) + for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { + auto* src = temp + i; + auto result_vec_fp32 = _mm256_set1_ps(0.0f); + for (size_t m = 0; m < M; m++) { + auto o_vec_fp32 = mm256_uni_loadu_ps(src); + result_vec_fp32 = _mm256_add_ps(result_vec_fp32, o_vec_fp32); + src += temp_stride; + } + mm256_uni_storeu_ps(dst + i, result_vec_fp32); + } +#endif + for (; i < S; i++) { + auto* src = temp + i; + float sum = 0.0f; + // sum result from all threads partition + for (size_t m = 0; m < M; m++) { + sum += src[0]; + src += temp_stride; + } + dst[i] = sum; + } +} + +// N and K must be multiple of 16 +template +void transpose_16Nx16K(TDST* dst, TSRC* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { + for (size_t k = 0; k < K; k += 16) { + for (size_t n = 0; n < N; n += 16) { + transpose_16x16_kernel(dst + n, src + n * src_stride, dst_stride, src_stride); + } + + dst += 16 * dst_stride; + src += 16; + } +} + +#if defined(HAVE_AVX512F) +static void transpose_16Nx16K(ov::bfloat16* dst, ov::bfloat16* src, ov::bfloat16* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { + // will treat as uint32_t transpose + auto s = reinterpret_cast(src); + auto d = reinterpret_cast(dst); + transpose_16Nx16K(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1); +} +#endif + +template +void transpose_16Nx16K(TDST* dst, uint8_t* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = src; + auto t = tmp; + for (size_t n = 0; n < N; n ++) { + auto f = reinterpret_cast(s); + attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); + s += src_stride + 2 * sizeof(float); + t += 
src_stride; + } + transpose_16Nx16K(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); +} + +// dequant f16/u8 to float +template +static inline void dequant(T* dst, T* src, size_t N, size_t K) { + // never called + OPENVINO_THROW("dequant: should not be called."); +} + +static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K) { + cvt_copy(dst, src, K * N); +} + +template +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = src; + for (size_t n = 0; n < N; n ++) { + auto f = reinterpret_cast(s); + attn_dequant_u8_kernel(s + 2 * sizeof(float), dst, K, f[0], f[1]); + s += K + 2 * sizeof(float); + dst += K; + } +} + +#if defined(HAVE_AVX512F) +// pack bf16/u8 to bf16 +static void pack_32x32_kernel(ov::bfloat16* dst, ov::bfloat16* src, size_t stride) { + static const uint64_t idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + auto midx = _mm512_loadu_si512(idx); + for (size_t i = 0; i < 16; i++) { + auto a = _mm512_loadu_si512(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + auto b = _mm512_loadu_si512(src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + auto B0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + auto B1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_si512(dst, B0); + _mm512_storeu_si512(dst + 32, B1); + src += 2 * stride; + dst += 2 * stride; + } +} + +static void pack_32x16_kernel(ov::bfloat16* dst, ov::bfloat16* src, size_t stride) { + static const uint64_t idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + auto midx = _mm512_loadu_si512(idx); + for (size_t i = 0; i < 16; i++) { + auto x = _mm256_loadu_si256(reinterpret_cast<__m256i*>(src)); // [a1 a2 a3 a4] total 256-bits in 4 64bits unit + auto y = _mm256_loadu_si256(reinterpret_cast<__m256i*>(src + stride)); // [b1 b2 b3 b4] total 256-bits + auto a = _mm512_castsi256_si512(x); + auto b = _mm512_castsi256_si512(y); + a = _mm512_permutexvar_epi64(midx, a); // [a1 x | a2 x | a3 x | a4 x] + b = _mm512_permutexvar_epi64(midx, b); // [b1 x | b2 x | b3 x | b4 x] + auto B0 = _mm512_unpacklo_epi16(a, b); + _mm512_storeu_si512(dst, B0); + src += 2 * stride; + dst += 2 * stride; + } +} + +static void pack_32Nx16K(ov::bfloat16* dst, ov::bfloat16* src, ov::bfloat16* tmp, size_t N, size_t K, size_t stride) { + for (size_t n = 0; n < N; n += 32) { + size_t k = 0; + for (; k + 32 <= K; k += 32) { + pack_32x32_kernel(dst + k * 2, src + k, stride); + } + if (k < K) + pack_32x16_kernel(dst + k * 2, src + k, stride); + + dst += 32 * stride; + src += 32 * stride; + } +} + +static void pack_32Nx16K(ov::bfloat16* dst, uint8_t* src, ov::bfloat16* tmp, size_t N, size_t K, size_t stride) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = src; + auto t = tmp; + for (size_t n = 0; n < N; n ++) { + auto f = reinterpret_cast(s); + attn_dequant_u8_kernel(s + 2 * 
sizeof(float), t, K, f[0], f[1]); + s += stride + 2 * sizeof(float); + t += stride; + } + pack_32Nx16K(dst, tmp, reinterpret_cast(0), N, K, stride); +} +#endif + +template +static void pack_32Nx16K(float* dst, T* src, float* tmp, size_t N, size_t K, size_t stride) { + // never called + OPENVINO_THROW("pack_32Nx16K: should not be called."); +} + +template +struct MHAHelper { + // initialize once + size_t _H; + size_t _S; + size_t _Hk; + size_t _h_each_group_len; + size_t _block_size; + size_t _nthr; + size_t _sliding_window; + float _d_scale; + + PlainTensor _weight; // [nthr, H, 32, rnd_up(kv_len, block_size)], shared by first and second loop along bh + PlainTensor _output; // [nthr, 32, H, S], shared by first and second loop along bh + PlainTensor _qk_scratch_a; // [nthr, scratch_a_size] + PlainTensor _qk_scratch_b; // [B, rnd_up(kv_len, block_size), Hk, scratch_b_size] + PlainTensor _wv_scratch_a; + PlainTensor _wv_scratch_b; + std::vector _wsp; + size_t _wsp_size_per_thread = 0; + + std::vector> _qk_gemm; + std::vector> _wv_gemm; + // will accumulate C buffer + std::vector> _wv_gemm_acc; + // second token + std::shared_ptr _gemv; + bool _fastpath_valid = false; + // second token for bhl loop + PlainTensor _weight_bhl; + PlainTensor _output_bhl; + + MHAHelper() { + _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); + } + + void init(size_t H, size_t S, size_t Hk, size_t h_each_group_len, size_t block_size, size_t sliding_window, + float d_scale, size_t kv_len) { + // query shape: [B, H, L, S] + // present_key shape: [block, H, 32, S] + // Q*K': [M1, S] * [M2, S]' + // kernel: Q:[1~block_size, S] * K':[block_size, S]' + // aka: M:1~block_size, N:block_size, K:S + // (Q*K')*V: [M1, M2] * [M2, S] + // kernel: (Q*K'):[1~block_size, block_size] * V:[block_size, S] + // aka: M:1~block_size, N:S, K:block_size + // Because K and V are from cache, can use M2'=rnd_up(M2, block_size) to simplify logic + auto in_type = precision_of::value; + _H = H; + _S = S; + _Hk = Hk; + _h_each_group_len = h_each_group_len; + _block_size = block_size; + _nthr = static_cast(parallel_get_max_threads()); + _sliding_window = sliding_window; + _d_scale = d_scale; + + auto prev_score_stride = _weight.stride(2); + auto want_score_stride = rnd_up(kv_len, _block_size); + auto new_score_stride = std::max(prev_score_stride, want_score_stride); + // resize temporary buffers, weight.size(3) will be aligned to block_size + _weight.resize({static_cast(_nthr), H, _block_size, new_score_stride}); + _output.resize({static_cast(_nthr), _block_size, H, S}); + + // TODO: kernel supports stride + if (_qk_gemm.empty() || prev_score_stride < new_score_stride) { + _qk_gemm.resize(_block_size); + _wv_gemm.resize(_block_size); + _wv_gemm_acc.resize(_block_size); + for (size_t i = 0; i < _block_size; i++) { + _qk_gemm[i] = std::make_shared(i + 1, + _block_size, + _S, + _H * _S, + _block_size, + _weight.stride(2), + false, + in_type); + _wv_gemm[i] = std::make_shared(i + 1, + _S, + _block_size, + // if it's bf16, the stride needs double due to reuse float buffer + (in_type == ov::element::Type_t::f32 ? 1 : 2) * _weight.stride(2), + _S, + _output.stride(1), + false, + in_type); + _wv_gemm_acc[i] = std::make_shared(i + 1, + _S, + _block_size, + // if it's bf16, the stride needs double due to reuse float buffer + (in_type == ov::element::Type_t::f32 ? 
1 : 2) * _weight.stride(2), + _S, + _output.stride(1), + false, + in_type, + true); + } + + // wsp is used to compute beta when K is blocked + _wsp_size_per_thread = _wv_gemm[0]->get_wsp_size(); + _wsp.resize(_nthr * _wsp_size_per_thread); + + // allocate scratch a/b, notice get_scratch_a_size/get_scratch_b_size returns in bytes + _qk_scratch_a.resize({_nthr, _qk_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); + _wv_scratch_a.resize({_nthr, _wv_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); + + _fastpath_valid = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && + (S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6) && precision_of::value == ov::element::bf16; + // aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing + if (_fastpath_valid && !_gemv) + _gemv = std::make_shared(static_cast(S), static_cast(block_size)); + } + } + + void init_reorder_buffers(size_t batch, size_t kv_len_in_blocks) { + _qk_scratch_b.resize({batch, kv_len_in_blocks, _Hk, _block_size * _S}); + _wv_scratch_b.resize({batch, kv_len_in_blocks, _Hk, _block_size * _S}); + } + + // compute one block(such as 32 tokens) of query in M dimension: softmax(q_block*k')*v + // all tensors such as query... have no batch dimension because batch dimension is varying + // query: [H, L, S] + // present_value: [block_number, H, 32, S] + // output_emb: [L, H * S] + // qk_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size] + // wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size] + void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb, + const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, + const int32_t* block_table, size_t ithr, size_t q_blk, size_t hk, size_t q_len, size_t cur_kv_len) { + auto q_start = q_blk * _block_size; + auto q_end = std::min(q_start + _block_size, q_len); + auto q_cnt = q_end - q_start; + constexpr bool q_is_bf16 = precision_of::value == ov::element::bf16; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size); + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + auto* q_ptr = query.ptr(h, q_start, 0); + float* c_ptr = _weight.ptr(ithr, h, 0, 0); + // for each query block, loop through all key block + // for blocks: + // 1 0 0 0 ... + // 1 1 0 0 ... + // 1 1 1 0 ... + // just computing the positions of 1 should be enough + for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) { + auto* k_ptr = qk_scratch_b.ptr(k_blk, hk); + _qk_gemm[q_cnt - 1]->executeGemm(q_cnt < _block_size, + q_ptr, + k_ptr, + c_ptr + k_blk * _block_size, + _wsp.data() + ithr * _wsp_size_per_thread, + _qk_scratch_a ? 
_qk_scratch_a.ptr(ithr, 0) : nullptr); + } + + for (size_t m = q_start; m < q_end; m++) { + // apply attention mask & sofmax + auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1); + auto score = _weight.ptr(ithr, h, m - q_start); + if (_sliding_window) { + size_t start_idx = 0; + auto new_causal = ncausal; + if (ncausal > _sliding_window) { + start_idx = ncausal - static_cast(_sliding_window); + new_causal = _sliding_window; + } + attn_softmax_kernel(score + start_idx, + reinterpret_cast(score) + start_idx, + _d_scale, + nullptr, + nullptr, + nullptr, + false, + new_causal, + rnd_up(cur_kv_len, _block_size) - start_idx, + precision_of::value, + precision_of::value); + + memset(score, 0, sizeof(DATA_TYPE) * start_idx); + } else { + attn_softmax_kernel(score, + reinterpret_cast(score), + _d_scale, + nullptr, + nullptr, + nullptr, + false, + ncausal, + rnd_up(cur_kv_len, _block_size), + precision_of::value, + precision_of::value); + } + } + + // reuse float buffer, need to use float to compute offset + auto* w_ptr = reinterpret_cast(_weight.ptr(ithr, h, 0, 0)); + float* fp32_out_ptr = q_is_bf16 ? _output.ptr(ithr, 0, h, 0) : output_emb.ptr(q_start, h * _S); + + // for each weight block, loop through all value block + for (size_t v_blk = 0; v_blk < cur_kv_len_blocks; v_blk++) { + DATA_TYPE* v_ptr; + if (q_is_bf16 || !q_cache_is_same) { + v_ptr = wv_scratch_b.ptr(v_blk, hk); + } else { + v_ptr = present_value.ptr(block_table[v_blk], hk); + } + if (v_blk == 0) { + _wv_gemm[q_cnt - 1]->executeGemm(q_cnt < _block_size, + w_ptr + v_blk * _block_size, + v_ptr, + fp32_out_ptr, + _wsp.data() + ithr * _wsp_size_per_thread, + _wv_scratch_a ? _wv_scratch_a.ptr(ithr, 0) : nullptr); + } else { + _wv_gemm_acc[q_cnt - 1]->executeGemm(q_cnt < _block_size, + w_ptr + v_blk * _block_size, + v_ptr, + fp32_out_ptr, + _wsp.data() + ithr * _wsp_size_per_thread, + _wv_scratch_a ? _wv_scratch_a.ptr(ithr, 0) : nullptr); + } + } + if (q_is_bf16) { + attn_memcpy2d_kernel(_output.ptr(ithr, 0, h, 0), + output_emb.ptr(q_start, h * _S), + ov::element::f32, + ov::element::bf16, + _output.stride(1), + output_emb.stride(0), + _S, + q_cnt); + } + } + } + + // compute one token, loop along batch and head dimensions + // all tensors such as query... 
have no batch dimension because batch dimension is varying + // query: [H, L, S] + // present_*: [block_number, H, 32, S] + // output_emb: [L, H * S] + // weight: [nthr, H, 32, rnd_up(kv_len, block_size)] + // output: [nthr, 32, H, S] + void exec_kernel_one_bh(const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb, + const int32_t* block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len) { + if (_fastpath_valid) { + _gemv->tile_config(); + for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) { + auto block_number = block_table[i]; + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + (*_gemv)(query.ptr(h, pq), present_key.ptr(block_number, hk), + _weight.ptr(ithr, h, pq) + pk); + } + } + } + _gemv->tile_release(); + } else { + for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) { + auto block_number = block_table[i]; + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + dot_product_block(query.ptr(h, pq), present_key.ptr(block_number, hk), + _weight.ptr(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk)); + } + } + } + } + + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + // apply attention mask & sofmax + attn_softmax_kernel(_weight.ptr(ithr, h, pq), + _weight.ptr(ithr, h, pq), + _d_scale, + nullptr, + nullptr, + nullptr, + false, + cur_kv_len, + cur_kv_len, + ov::element::f32, + ov::element::f32); + } + } + + memset(_output.ptr(ithr), 0, q_len * _H * _S * sizeof(float)); + for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) { + auto block_number = block_table[i]; + auto* v = present_value.ptr(block_number, hk); + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + attn_acc_value_block(_output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v, + _S, + std::min(_block_size, cur_kv_len - pv)); + } + } + } + // convert to dst + for (size_t pq = 0; pq < q_len; pq++) + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) + cvt_copy(output_emb.ptr(pq, h * _S), _output.ptr(ithr, pq, h), _S); + } + + // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small batch tokens. + // It will assume NO mixture execution of first and second token. + // all tensors such as query... 
have batch dimension which is DIFFERENT from above + // query: [B, H, L, S] + // present_*: [block_number, H, 32, S] + // output_emb: [B, L, H * S] + // 3 loops along batch, head, kv cache length dimensions + void exec_loop_bhl(const PlainTensor& query, + const PlainTensor& present_key, + const PlainTensor& present_value, + const PlainTensor& output_emb, + const PlainTensor& block_tables, + size_t max_context_len, + const PlainTensor& context_lens) { + auto B = query.size(0); + auto q_len = query.size(2); + auto kv_len_in_blocks = block_tables.m_dims[1]; + + // aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing + _weight_bhl.resize({B, _H, q_len, rnd_up(max_context_len, std::max(_block_size, size_t{16}))}); + + parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pk_in_blocks, size_t hk) { + auto context_len = static_cast(context_lens.ptr()[b]); + // kv_len must be valid + auto pk = pk_in_blocks * _block_size; + if (pk < context_len) { + auto block_number = block_tables.ptr(b)[pk_in_blocks]; + if (_fastpath_valid) { + _gemv->tile_config(); + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + (*_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk); + } + } + _gemv->tile_release(); + } else { + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk)); + } + } + } + } + }); + + parallel_for3d_dynamic(B, _H, q_len, [&](size_t b, size_t h, size_t pq) { + auto cur_kv_len = static_cast(context_lens.ptr()[b]); + auto ncausal = cur_kv_len; + // apply attention mask & sofmax + attn_softmax_kernel(_weight_bhl.ptr(b, h, pq), + _weight_bhl.ptr(b, h, pq), + _d_scale, + nullptr, + nullptr, + nullptr, + false, + ncausal, + cur_kv_len, + ov::element::f32, + ov::element::f32); + }); + + // attn_w * V + _output_bhl.resize({static_cast(_nthr), B, q_len, _H, _S}); + // m_attn_w {B, H, q_len, kv_len} + parallel_nt_static(_nthr, [&](const size_t ithr, const size_t nthr) { + memset(_output_bhl.ptr(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float)); + }); + + parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pv_in_blocks, size_t hk) { + auto ithr = parallel_get_thread_num(); + auto context_len = static_cast(context_lens.ptr()[b]); + auto pv = pv_in_blocks * _block_size; + // kv_len must be valid + if (pv < context_len) { + auto block_number = block_tables.ptr(b)[pv_in_blocks]; + auto* v = present_value.ptr(block_number, hk); + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) { + attn_acc_value_block(_output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v, + _S, + std::min(_block_size, context_len - pv)); + } + } + } + }); + + parallel_for3d(B, _H, q_len, [&](size_t b, size_t h, size_t pq) { + auto* temp = _output_bhl.ptr(0, b, pq, h); + size_t temp_stride = _output_bhl.stride(0); + auto* dst = output_emb.ptr(b, pq, h * _S); + attn_reduce(dst, temp, _nthr, _S, temp_stride); + }); + } +}; + +template +struct MHAMultiple { + MHAHelper& _helper; + + MHAMultiple(MHAHelper& helper) : _helper(helper) {} + + void operator()(PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + PlainTensor& output_emb, + const 
PlainTensor& block_tables, + size_t max_context_len, + const PlainTensor& context_lens) { + auto B = query.m_dims[0]; + auto Hk = present_value.m_dims[1]; + constexpr bool q_is_bf16 = precision_of::value == ov::element::bf16; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + + // buffer for transpose and repack + _helper.init_reorder_buffers(B, block_tables.m_dims[1]); + + // packed k, v + parallel_for3d_dynamic(B, block_tables.m_dims[1], Hk, [&](size_t b, size_t kv_block, size_t hk) { + auto block_number = block_tables.ptr(b)[kv_block]; + if (block_number < 0) + return; + auto ithr = parallel_get_thread_num(); + auto* k_ptr = present_key.ptr(block_number, hk); + auto* v_ptr = present_value.ptr(block_number, hk); + // in AttentionExecutor::executor block_size must be multiple of 32 and head_size must be multiple of 16, + // transpose 16Nx16K/pack 32Nx16K should be enough + transpose_16Nx16K(_helper._qk_scratch_b.template ptr(b, kv_block, hk), + k_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, _helper._block_size, _helper._S); + if (q_is_bf16) { + pack_32Nx16K(_helper._wv_scratch_b.template ptr(b, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, + _helper._S); + } else { + // if not bf16 and type of kvcache is not same with query, we need to decompress the kvcache. + // Currently dequant f16/u8 to f32 + if (!q_cache_is_same) { + dequant(_helper._wv_scratch_b.template ptr(b, kv_block, hk), v_ptr, _helper._block_size, _helper._S); + } + } + }); + + // query breaks to [B, H, m_blocks, block_size, S], k cache is split to [B, H, m_blocks', S, block_size] + // v cache may be [B, H, m_blocks', block_size, S] or [block_number, H, block_size, S] + // outer loop will use B, H, m_blocks to walkthrough query + parallel_for3d_dynamic(B, block_tables.m_dims[1], Hk, [&](size_t b, size_t q_blk, size_t hk) { + if (block_tables.ptr(b)[q_blk] < 0) + return; + size_t ithr = parallel_get_thread_num(); + auto cur_kv_len = static_cast(context_lens.ptr()[b]); + auto q_len = cur_kv_len; + _helper.exec_kernel_multiple(query.slice(0, b, b), present_value, output_emb.slice(0, b, b), + _helper._qk_scratch_b.slice(0, b, b), _helper._wv_scratch_b.slice(0, b, b), + block_tables.ptr(b), ithr, q_blk, hk, q_len, std::min(cur_kv_len, (q_blk + 1) * _helper._block_size)); + }); + } +}; + +// 2nd token case : only 1 token in query +template +struct MHASingle { + MHAHelper& _helper; + + MHASingle(MHAHelper& helper) : _helper(helper) {} + + // one loop along batch and head dimensions + void exec_loop_bh(PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + PlainTensor& output_emb, + const PlainTensor& block_tables, + size_t max_context_len, + const PlainTensor& context_lens) { + auto B = query.m_dims[0]; + auto Hk = present_value.m_dims[1]; + parallel_for2d_dynamic(B, Hk, [&](size_t b, size_t hk) { + size_t ithr = parallel_get_thread_num(); + auto cur_kv_len = static_cast(context_lens.ptr()[b]); + auto q_len = 1ul; + _helper.exec_kernel_one_bh(query.slice(0, b, b), present_key, present_value, + output_emb.slice(0, b, b), block_tables.ptr(b), ithr, hk, q_len, cur_kv_len); + }); + } + + // Q, K, V is ready, do attention + // query [B, H, q_len, S] + // present_key [B, H, kv_len, S] stride of last dim maybe > 1 + // present_value [B, H, kv_len, S] + // output_emb [B, L1, H, S] + void operator()(PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + PlainTensor& 
output_emb, + const PlainTensor& block_tables, + size_t max_context_len, + const PlainTensor& context_lens) { + auto B = query.size(0); + auto nthr = static_cast(parallel_get_max_threads()); + + if (B >= nthr) { + exec_loop_bh(query, present_key, present_value, output_emb, block_tables, max_context_len, context_lens); + } else { + _helper.exec_loop_bhl(query, present_key, present_value, output_emb, block_tables, max_context_len, context_lens); + } + } +}; + +template +struct MHAMixed { + MHAHelper& _helper; + struct AttnWorkItem { + int32_t batch_in_reorder; // which batch in reorder buffer will be used + int32_t batch_in_query; // batch idx in query + int32_t q_len; // current sequence length, 1 for second token, 2+ for first token + int32_t q_block_id; // block id in this seq, valid at first token + }; + struct ReorderWorkItem { + int32_t batch_in_query_last; // last batch idx in a sentence + int32_t batch_in_reorder; // which batch in reorder buffer will be used + int32_t kv_block_id; // block id in this kv cache seq + }; + struct WorkItems { + private: + std::vector attn_items; + std::vector reorder_items; + int32_t max_kv_len_in_reorder; // max kv len between first tokens + int32_t max_batch_in_reorder; + int32_t total_kv_len; + + public: + void reset(const PlainTensor& query, const PlainTensor& context_lens, const PlainTensor& subsequence_lens, size_t block_size) { + attn_items.clear(); + reorder_items.clear(); + max_kv_len_in_reorder = 0; + max_batch_in_reorder = 0; + total_kv_len = 0; + + int32_t start_batch_in_query = 0; + auto seq_cout = static_cast(subsequence_lens.m_dims[0]); + for (int32_t i = 0; i < seq_cout; i++) { + auto q_len = subsequence_lens.ptr()[i]; + // workitems for transpose, repack + // last token corresponding batch index + auto batch_in_query_last = start_batch_in_query + q_len - 1; + auto kv_len = context_lens.ptr()[batch_in_query_last]; + auto kv_len_in_block = static_cast(div_up(kv_len, block_size)); + if (q_len == 1) { + attn_items.emplace_back(AttnWorkItem{ + 0, // batch_in_reorder + start_batch_in_query, // batch_in_query + 1ull, // q_len + // kv_len in blocks, used in the sort function + kv_len_in_block - 1 + }); + start_batch_in_query++; + } else { + auto reorder_sub_work_count = kv_len_in_block; + max_kv_len_in_reorder = std::max(max_kv_len_in_reorder, kv_len); + for (int32_t block_id = 0; block_id < reorder_sub_work_count; block_id++) { + reorder_items.emplace_back(ReorderWorkItem{ + batch_in_query_last, // batch_in_query_last + max_batch_in_reorder, // batch_in_reorder + block_id // kv_block_id + }); + } + + // workitems for attention + auto attn_sub_work_count = static_cast(div_up(q_len, block_size)); + for (int32_t block_id = 0; block_id < attn_sub_work_count; block_id++) { + attn_items.emplace_back(AttnWorkItem{ + max_batch_in_reorder, // batch_in_reorder + start_batch_in_query, // batch_in_query + q_len, // q_len + block_id // q_block_id + }); + } + start_batch_in_query += q_len; + max_batch_in_reorder++; + } + total_kv_len += kv_len; + } + // std::sort(attn_items.begin(), attn_items.end(), [] (const AttnWorkItem& left, const AttnWorkItem& right) { + // // kv block number which will be acessed later + // auto left_kv_blocks = left.q_block_id; + // auto right_kv_blocks = right.q_block_id; + // return left_kv_blocks > right_kv_blocks; + // }); + } + const AttnWorkItem& get_attn_work_item(size_t idx) const { + return attn_items[idx]; + } + size_t attn_work_size() const { + return attn_items.size(); + } + const ReorderWorkItem& 
get_reorder_work_item(size_t idx) const { + return reorder_items[idx]; + } + size_t reorder_work_size() const { + return reorder_items.size(); + } + size_t get_reorder_max_batch_size() const { + return static_cast(max_batch_in_reorder); + } + size_t get_reorder_max_kv_len() const { + return static_cast(max_kv_len_in_reorder); + } + size_t get_total_kv_len() const { + return static_cast(total_kv_len); + } + }; + + WorkItems _workitems; + + MHAMixed(MHAHelper& helper) : _helper(helper) {} + + // one loop to handle first and second tokens + void exec_loop_mixed(const PlainTensor& query, + const PlainTensor& present_key, + const PlainTensor& present_value, + const PlainTensor& output_emb, + const PlainTensor& block_tables, + size_t max_context_len, + const PlainTensor& context_lens, + const PlainTensor& subsequence_lens) { + auto Hk = present_value.m_dims[1]; + + constexpr bool q_is_bf16 = precision_of::value == ov::element::bf16; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + auto attn_work_count = _workitems.attn_work_size(); + auto reorder_work_count = _workitems.reorder_work_size(); + + // buffer for transpose and repack + _helper.init_reorder_buffers(_workitems.get_reorder_max_batch_size(), div_up(_workitems.get_reorder_max_kv_len(), _helper._block_size)); + + // packed k, v + parallel_for2d_dynamic(reorder_work_count, Hk, [&](size_t w, size_t hk) { + const auto& item = _workitems.get_reorder_work_item(w); + const auto batch_in_query_last = item.batch_in_query_last; + const auto batch_in_reorder = item.batch_in_reorder; + const auto kv_block = item.kv_block_id; + auto block_number = block_tables.ptr(batch_in_query_last)[kv_block]; + if (block_number < 0) + return; + + auto ithr = parallel_get_thread_num(); + auto* k_ptr = present_key.ptr(block_number, hk); + auto* v_ptr = present_value.ptr(block_number, hk); + transpose_16Nx16K(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + k_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, _helper._block_size, _helper._S); + if (q_is_bf16) { + pack_32Nx16K(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, + _helper._S); + } else { + // need to decompress + if (!q_cache_is_same) { + dequant(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), v_ptr, _helper._block_size, _helper._S); + } + } + }); + + parallel_for2d_dynamic(attn_work_count, Hk, [&](size_t w, size_t hk) { + const auto& item = _workitems.get_attn_work_item(w); + const auto batch_in_query = item.batch_in_query; + const auto q_len = static_cast(item.q_len); + size_t ithr = parallel_get_thread_num(); + + if (q_len == 1) { + const auto cur_kv_len = static_cast(context_lens.ptr()[batch_in_query]); + + _helper.exec_kernel_one_bh(query.slice(0, batch_in_query, batch_in_query), present_key, present_value, + output_emb.slice(0, batch_in_query, batch_in_query), block_tables.ptr(batch_in_query), ithr, hk, 1ul, cur_kv_len); + } else { + const auto batch_in_reorder = item.batch_in_reorder; + const auto q_blk = item.q_block_id; + const auto q_start = static_cast(batch_in_query) + q_blk * _helper._block_size; + const auto q_cnt = std::min(_helper._block_size, q_len - q_blk * _helper._block_size); + const auto cur_kv_len = static_cast(context_lens.ptr()[q_start + q_cnt - 1]); + + PlainTensor sub_query; + sub_query.resize({q_len, _helper._H, _helper._S}, query.ptr(batch_in_query)); + sub_query = 
sub_query.permute({1, 0, 2}); + _helper.exec_kernel_multiple(sub_query, + present_value, + output_emb.slice(0, batch_in_query, batch_in_query + q_len).reshape({q_len, _helper._H * _helper._S}), + _helper._qk_scratch_b.slice(0, batch_in_reorder, batch_in_reorder), + _helper._wv_scratch_b.slice(0, batch_in_reorder, batch_in_reorder), + block_tables.ptr(q_start + q_cnt - 1), + ithr, + q_blk, + hk, + q_len, + cur_kv_len); + } + }); + } + + // Q, K, V is ready, do attention + void operator()(PlainTensor& query, + PlainTensor& present_key, + PlainTensor& present_value, + PlainTensor& output_emb, + const PlainTensor& block_tables, + size_t max_context_len, + const PlainTensor& context_lens, + const PlainTensor& subsequence_lens) { + _workitems.reset(query, context_lens, subsequence_lens, _helper._block_size); + + auto nthr = static_cast(parallel_get_max_threads()); + + if (subsequence_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) { + exec_loop_mixed(query, present_key, present_value, output_emb, block_tables, max_context_len, context_lens, subsequence_lens); + } else { + _helper.exec_loop_bhl(query, present_key, present_value, output_emb, block_tables, max_context_len, context_lens); + } + } +}; + +template +struct AttentionExecutor : public PagedAttentionExecutor { + MHAHelper _helper; + MHAMultiple _kernel_multiple; + MHASingle _kernel_single; + MHAMixed _kernel_mixed; + + AttentionExecutor() : _kernel_multiple(_helper), _kernel_single(_helper), _kernel_mixed(_helper) {} + + void execute(const std::vector& inputs, const MemoryPtr output) override { + bool is_prompt = false; + PlainTensor present_key, present_value; + PlainTensor q_input; // f32[B, H, L1, S] + PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] + PlainTensor v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] + PlainTensor block_tables; // i32[B, max_kvLen] + PlainTensor context_lens; + PlainTensor output_emb(output); + float scale_input = 0.0f; + size_t B, L1, S, H, Hk, h_each_group_len; + size_t sliding_window = 0; + size_t max_context_len = 0; + + q_input.reset(inputs[0]); + k_input.reset(inputs[1]); + v_input.reset(inputs[2]); + present_key.reset(inputs[ID_KCACHE]); + present_value.reset(inputs[ID_VCACHE]); + auto block_size = present_key.size(2); + + is_prompt = *inputs[ID_IS_PROMPT]->getDataAs() == 1; + max_context_len = static_cast(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs()); + context_lens.reset(inputs[ID_CONTEXT_LENS]); + block_tables.reset(inputs[ID_BLOCK_TABLES]); + scale_input = *inputs[ID_SCALE]->getDataAs(); + + // q: [B, L1, H*S], kv: [B, L1, Hk*S] + // k_cache: [NUM_BLOCKS, Hk, 32, S] + // v_cache: [NUM_BLOCKS, Hk, 32, S] + // context_lens: [B] + // block_tables: [B, max_block_per_request] + B = k_input.size(0); + L1 = k_input.size(1); + Hk = present_key.size(1); + // The layout for per token per head for u8 kv cache: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The actual size needs to deduct scale and zeropoint. + S = present_value.size(3) - (present_value.m_dt == ov::element::Type_t::u8 ? 
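+        // per the layout comment above, each u8 kv-cache token row starts with an 8-byte prefix (f32 scale + f32 zero-point) ahead of the S quantized values, so the head size is the row length minus that prefix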
sizeof(float) * 2 : 0); + H = q_input.size(2) / S; + h_each_group_len = 1; + if (Hk != H) { + h_each_group_len = H / Hk; + } + if (scale_input == 0.0f) + scale_input = 1.0f / sqrt(S); + + // TODO: enable block_size to be multiple of 32 + OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size); + OPENVINO_ASSERT(S % 16 == 0, "CPU: head size must be multiple of 16, current: ", S); + + q_input.assert_dims({B, L1, H * S}); + output_emb.assert_dims({B, L1, H * S}); + q_input = q_input.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); + k_input = k_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3}); + v_input = v_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3}); + + _helper.init(H, S, Hk, h_each_group_len, block_size, sliding_window, scale_input, max_context_len); + + if (is_prompt) { + sliding_window = static_cast(*inputs[ID_SLIDING_WINDOW]->getDataAs()); + // always construct block_tables, max_context_len, context_lens from slot_mapping + { + PlainTensor slot_mapping; + slot_mapping.reset(inputs[ID_SLOT_MAPPING]); // [B, max_context_len] + block_tables.resize({B, div_up(max_context_len, block_size)}); + context_lens.resize({B}); + for (size_t i = 0; i < B; i++) { + context_lens.ptr()[i] = 0; + for (size_t j = 0; j < block_tables.m_dims[1]; j++) { + auto slot = slot_mapping.ptr(i)[j * block_size]; + block_tables.ptr(i)[j] = slot >= 0 ? slot / block_size : -1; + for (size_t k = j * block_size; k < (j + 1) * block_size && k < max_context_len; k++) { + if (slot_mapping.ptr(i)[k] < 0) + break; + context_lens.ptr()[i]++; + } + } + } + } + + // multi-token version + _kernel_multiple(q_input, present_key, present_value, output_emb, block_tables, max_context_len, context_lens); + } else { + context_lens.assert_dims({B}); + block_tables.assert_dims({B, 0}, true); + if (inputs.size() > 13) { + // first and second tokens mixed path + // subsequence_lens contains the length of each sequence + PlainTensor subsequence_lens; + subsequence_lens.reset(inputs[ID_SUBSEQUENCE_LENS]); + + _kernel_mixed(q_input, present_key, present_value, output_emb, block_tables, max_context_len, context_lens, subsequence_lens); + } else { + _kernel_single(q_input, present_key, present_value, output_emb, block_tables, max_context_len, context_lens); + } + } + } +}; +#endif + +std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type) { + std::shared_ptr executor; + +#ifdef OPENVINO_ARCH_X86_64 + if (data_type == ov::element::bf16) { +#if defined(HAVE_AVX512F) + if (kvcache_type == ov::element::u8) { + executor = std::make_shared>(); + } else { + executor = std::make_shared>(); + } +#else + OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); +#endif + } else if (data_type == ov::element::f32) { + if (kvcache_type == ov::element::u8) { + executor = std::make_shared>(); + } else if (kvcache_type == ov::element::f16) { + executor = std::make_shared>(); + } else { + executor = std::make_shared>(); + } + } else { + OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); + } +#else + OPENVINO_THROW("make_pa_executor: only support x64 platform"); +#endif + return executor; +} + +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp new file mode 100644 index 00000000000000..ed779dee13c96d --- /dev/null +++ 
b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include +#include "cpu_memory.h" +#include "executor_pa_common.hpp" + +namespace ov { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type); + +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp new file mode 100644 index 00000000000000..63a8a0f7d24062 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include +#include +#include +#include +#include + +#include "openvino/core/type/bfloat16.hpp" +#include "openvino/core/parallel.hpp" +#include "executor_pa_common.hpp" +#include "utils/plain_tensor.hpp" + +namespace ov { +namespace Extensions { +namespace Cpu { + +using namespace ov; +using namespace ov::intel_cpu; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + +#ifdef OPENVINO_ARCH_X86_64 + +void TileConfig::reset(int palette, int _startRow, const std::vector>& _rows_columnsBytes) { + palette_id = palette; + startRow = _startRow; + unsigned long i; + for (i = 0; i < 14; i++) { + reserved[i] = 0; + } + for (i = 0; i < _rows_columnsBytes.size(); i++) { + rows[i] = _rows_columnsBytes[i].first; + cols[i] = _rows_columnsBytes[i].second; + } + for (; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } +} + +TileConfiger::TileConfiger() : jit_generator(jit_name()) { + create_kernel(); +} + +void TileConfiger::generate() { + Xbyak::Label release; + test(abi_param1, abi_param1); + jz(release); + ldtilecfg(ptr[abi_param1]); + ret(); + L(release); + tilerelease(); + ret(); +} + +JitMatMulVecAMX::JitMatMulVecAMX(int head_size, int block_size) : jit_generator(jit_name()), m_head_size(head_size), m_block_size(block_size) { + create_kernel(); + m_tile_cfg.reset(1, + 0, + { + {16, 4}, // C:0 M x 1 (4b) + {16, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); +} + +void JitMatMulVecAMX::generate() { + mov(reg_stride_A, m_head_size * 2); + mov(reg_stride_BC, 4); + const int kStep = 32; + if ((m_head_size % 32) != 0) + throw std::runtime_error("head size is not multiple of 32"); + if ((m_block_size % 16) != 0) + throw std::runtime_error("block size is not multiple of 16"); + auto num_B_tiles = m_head_size / kStep; + if (num_B_tiles > 6) + throw std::runtime_error("number of B tiles is bigger than 6"); + + /* + B(query) head_size x 1 + A(key) matrix : block_size x head_size C(dst) block_size x 1 + */ + // load query into B tiles + for (int i = 0; i < num_B_tiles; i++) { + tileloadd(Xbyak::Tmm(tmmB0.getIdx() + i), ptr[reg_q_addr + reg_stride_BC + i * 64]); + } + + for (int m = 0; m < m_block_size; m += 16) { + tilezero(tmmC); + for (int i = 0; i < num_B_tiles; i++) { + tileloadd(tmmA, ptr[reg_k_addr + reg_stride_A + i * 64]); + tdpbf16ps(tmmC, tmmA, Xbyak::Tmm(tmmB0.getIdx() + i)); + } + tilestored(ptr[reg_dst_addr + reg_stride_BC + m * sizeof(float)], tmmC); + 
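+        // advance the A (key) pointer to the next 16 rows of the block: 16 * head_size * sizeof(bf16) bytes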
add(reg_k_addr, m_head_size * 2 * 16); + } + ret(); +} + +#endif + +} // namespace Cpu +} // namespace Extensions +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp new file mode 100644 index 00000000000000..7b248f4643a5d0 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp @@ -0,0 +1,107 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include +#include "cpu_memory.h" +#include "cpu/x64/cpu_isa_traits.hpp" +#include "cpu/x64/jit_generator.hpp" + +namespace ov { +namespace Extensions { +namespace Cpu { + +// this file will contain features that do not require multiple instantiation + +struct PagedAttentionExecutor { + // PagedAttention input index + static const size_t ID_Q = 0; + static const size_t ID_K = 1; + static const size_t ID_V = 2; + static const size_t ID_KCACHE = 3; + static const size_t ID_VCACHE = 4; + static const size_t ID_IS_PROMPT = 5; + static const size_t ID_SLOT_MAPPING = 6; + static const size_t ID_MAX_CONTEXT_LEN = 7; + static const size_t ID_CONTEXT_LENS = 8; + static const size_t ID_BLOCK_TABLES = 9; + static const size_t ID_SCALE = 10; + static const size_t ID_ALIBI_SLOPES = 11; + static const size_t ID_SLIDING_WINDOW = 12; + static const size_t ID_SUBSEQUENCE_LENS = 13; + virtual void execute(const std::vector& inputs, const ov::intel_cpu::MemoryPtr output) = 0; +}; + +#ifdef OPENVINO_ARCH_X86_64 + +// w = query * Key +// +// query: [1, S] +// Key : [block_size, S] +// w : [1, block_size] +// +// S is known at compile time +struct TileConfig { + uint8_t palette_id; + uint8_t startRow; + uint8_t reserved[14]; + uint16_t cols[16]; + uint8_t rows[16]; + void reset(int palette, int _startRow, const std::vector>& _rows_columnsBytes); +}; + +class TileConfiger : public dnnl::impl::cpu::x64::jit_generator { +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(TileConfiger) + TileConfiger(); + void generate() override; +}; + +class JitMatMulVecAMX : public dnnl::impl::cpu::x64::jit_generator { + void operator=(const JitMatMulVecAMX&); + +public: + DECLARE_CPU_JIT_AUX_FUNCTIONS(JitMatMulVecAMX) + int m_head_size; + int m_block_size; + TileConfiger m_tile_configer; + TileConfig m_tile_cfg; + JitMatMulVecAMX(int head_size, int block_size); + + void tile_config() { + m_tile_configer(&m_tile_cfg); + } + void tile_release() { + m_tile_configer(nullptr); + } + + // to save push/pop: do not use `abi_save_gpr_regs` + static constexpr auto abi_param_regs = dnnl::impl::cpu::x64::abi_param_regs; + Xbyak::Reg64 reg_q_addr = abi_param1; + Xbyak::Reg64 reg_k_addr = abi_param2; + Xbyak::Reg64 reg_dst_addr = abi_param3; + Xbyak::Reg64 reg_stride_A = rax; + Xbyak::Reg64 reg_stride_BC = r9; + + Xbyak::Tmm tmmC = tmm0; + Xbyak::Tmm tmmA = tmm1; + Xbyak::Tmm tmmB0 = tmm2; + Xbyak::Tmm tmmB1 = tmm3; + Xbyak::Tmm tmmB2 = tmm4; + Xbyak::Tmm tmmB3 = tmm5; + Xbyak::Tmm tmmB4 = tmm6; + Xbyak::Tmm tmmB5 = tmm7; + + void generate() override; +}; + +#endif + +} // namespace Cpu +} // namespace Extensions +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index fe2101cb07e048..e4648ece365e9a 100644 --- 
a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -59,231 +59,6 @@ void cvt_copy(TA* dst, TB* src, size_t n) { } } -template -static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size) { -#if defined(HAVE_AVX512F) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto attn_w_vec0 = _mm512_set1_ps(weight[0]); - auto attn_w_vec1 = _mm512_set1_ps(weight[1]); - auto attn_w_vec2 = _mm512_set1_ps(weight[2]); - auto attn_w_vec3 = _mm512_set1_ps(weight[3]); - size_t i = 0; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - v_out = _mm512_fmadd_ps(attn_w_vec0, mm512_uni_loadu_ps(v + i), v_out); - v_out = _mm512_fmadd_ps(attn_w_vec1, mm512_uni_loadu_ps(v + i + S), v_out); - v_out = _mm512_fmadd_ps(attn_w_vec2, mm512_uni_loadu_ps(v + i + S * 2), v_out); - v_out = _mm512_fmadd_ps(attn_w_vec3, mm512_uni_loadu_ps(v + i + S * 3), v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * v[i]; - out[i] += weight[1] * v[i + S]; - out[i] += weight[2] * v[i + S * 2]; - out[i] += weight[3] * v[i + S * 3]; - } - v += 4 * S; - weight += 4; - } - if (j + 2 <= block_size) { - auto attn_w_vec0 = _mm512_set1_ps(weight[0]); - auto attn_w_vec1 = _mm512_set1_ps(weight[1]); - size_t i = 0; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - v_out = _mm512_fmadd_ps(attn_w_vec0, mm512_uni_loadu_ps(v + i), v_out); - v_out = _mm512_fmadd_ps(attn_w_vec1, mm512_uni_loadu_ps(v + i + S), v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * v[i]; - out[i] += weight[1] * v[i + S]; - } - v += 2 * S; - weight += 2; - j += 2; - } - if (j < block_size) { - auto attn_w_vec0 = _mm512_set1_ps(weight[0]); - size_t i = 0; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - v_out = _mm512_fmadd_ps(attn_w_vec0, mm512_uni_loadu_ps(v + i), v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * v[i]; - } - } - return; -#elif defined(HAVE_AVX2) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto attn_w_vec0 = _mm256_set1_ps(weight[0]); - auto attn_w_vec1 = _mm256_set1_ps(weight[1]); - auto attn_w_vec2 = _mm256_set1_ps(weight[2]); - auto attn_w_vec3 = _mm256_set1_ps(weight[3]); - size_t i = 0; - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + i); - v_out = _mm256_fmadd_ps(attn_w_vec0, mm256_uni_loadu_ps(v + i), v_out); - v_out = _mm256_fmadd_ps(attn_w_vec1, mm256_uni_loadu_ps(v + i + S), v_out); - v_out = _mm256_fmadd_ps(attn_w_vec2, mm256_uni_loadu_ps(v + i + S * 2), v_out); - v_out = _mm256_fmadd_ps(attn_w_vec3, mm256_uni_loadu_ps(v + i + S * 3), v_out); - - mm256_uni_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * v[i]; - out[i] += weight[1] * v[i + S]; - out[i] += weight[2] * v[i + S * 2]; - out[i] += weight[3] * v[i + S * 3]; - } - v += 4 * S; - weight += 4; - } - if (j + 2 <= block_size) { - auto attn_w_vec0 = _mm256_set1_ps(weight[0]); - auto attn_w_vec1 = _mm256_set1_ps(weight[1]); - size_t i = 0; - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + i); - v_out = _mm256_fmadd_ps(attn_w_vec0, mm256_uni_loadu_ps(v + i), v_out); - v_out = 
_mm256_fmadd_ps(attn_w_vec1, mm256_uni_loadu_ps(v + i + S), v_out); - - mm256_uni_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * v[i]; - out[i] += weight[1] * v[i + S]; - } - v += 2 * S; - weight += 2; - j += 2; - } - if (j < block_size) { - auto attn_w_vec0 = _mm256_set1_ps(weight[0]); - size_t i = 0; - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + i); - v_out = _mm256_fmadd_ps(attn_w_vec0, mm256_uni_loadu_ps(v + i), v_out); - - mm256_uni_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * v[i]; - } - } - return; -#endif - for (size_t j = 0; j < block_size; j++) { - for (size_t i = 0; i < S; i++) { - out[i] += weight[j] * v[i]; - } - v += S; - } -} - -static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) -#if defined(HAVE_AVX512F) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto v_f0 = reinterpret_cast(v); - auto v_f1 = reinterpret_cast(v + S + 8); - auto v_f2 = reinterpret_cast(v + 2 * (S + 8)); - auto v_f3 = reinterpret_cast(v + 3 * (S + 8)); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); - auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); - auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - auto zp1 = _mm512_set1_ps(v_f1[1]); - auto zp2 = _mm512_set1_ps(v_f2[1]); - auto zp3 = _mm512_set1_ps(v_f3[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), zp0); - auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + S + 8)))), zp1); - auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 2 * (S + 8))))), zp2); - auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 3 * (S + 8))))), zp3); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; - out[i] += weight[1] * (v[i + S + 8] - v_f1[1]) * v_f1[0]; - out[i] += weight[2] * (v[i + 2 * (S + 8)] - v_f2[1]) * v_f2[0]; - out[i] += weight[3] * (v[i + 3 * (S + 8)] - v_f3[1]) * v_f3[0]; - } - v += 4 * (S + 8) - 8; - weight += 4; - } - for (; j < block_size; j++) { - auto v_f0 = reinterpret_cast(v); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), zp0); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * 
(v[i] - v_f0[1]) * v_f0[0]; - } - v += S; - weight++; - } - return; -#elif defined(HAVE_AVX2) - size_t j = 0; - for (; j < block_size; j++) { - auto v_f0 = reinterpret_cast(v); - auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm256_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + i); - auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i)))), zp0); - v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); - - mm256_uni_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; - } - v += S; - weight++; - } - return; -#endif - for (size_t j = 0; j < block_size; j++) { - auto v0 = reinterpret_cast(v); - v += 8; - for (size_t i = 0; i < S; i++) { - out[i] += weight[j] * (v[i] - v0[1]) * v0[0]; - } - v += S; - } -} - template static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scale, float* zp) { size_t i = 0; @@ -792,271 +567,6 @@ static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, f #endif } -template -static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size) { -#if defined(HAVE_AVX512F) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm512_setzero_ps(); - auto vsum1 = _mm512_setzero_ps(); - auto vsum2 = _mm512_setzero_ps(); - auto vsum3 = _mm512_setzero_ps(); - size_t i = 0; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - vsum0 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i), vsum0); - vsum1 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i + n), vsum1); - vsum2 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i + 2 * n), vsum2); - vsum3 = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i + 3 * n), vsum3); - } - float sum0 = _mm512_reduce_add_ps(vsum0); - float sum1 = _mm512_reduce_add_ps(vsum1); - float sum2 = _mm512_reduce_add_ps(vsum2); - float sum3 = _mm512_reduce_add_ps(vsum3); - for (; i < n; i++) { - sum0 += a[i] * b[i]; - sum1 += a[i] * b[i + n]; - sum2 += a[i] * b[i + 2 * n]; - sum3 += a[i] * b[i + 3 * n]; - } - c[0] = sum0; - c[1] = sum1; - c[2] = sum2; - c[3] = sum3; - c += 4; - b += 4 * n; - } - for (; j < block_size; j++) { - auto vsum = _mm512_setzero_ps(); - size_t i = 0; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - vsum = _mm512_fmadd_ps(va, mm512_uni_loadu_ps(b + i), vsum); - } - float sum = _mm512_reduce_add_ps(vsum); - for (; i < n; i++) { - sum += a[i] * b[i]; - } - b += n; - *c++ = sum; - } - return; -#elif defined(HAVE_AVX2) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm256_set1_ps(0.0f); - auto vsum1 = _mm256_set1_ps(0.0f); - auto vsum2 = _mm256_set1_ps(0.0f); - auto vsum3 = _mm256_set1_ps(0.0f); - size_t i = 0; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - vsum0 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i), vsum0); - vsum1 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i + n), vsum1); - vsum2 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i + 2 * n), vsum2); - vsum3 = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i + 3 * n), vsum3); - } - hsum(vsum0); - hsum(vsum1); - hsum(vsum2); - hsum(vsum3); - float sum0 = _mm256_cvtss_f32(vsum0); - float sum1 = _mm256_cvtss_f32(vsum1); - float sum2 = _mm256_cvtss_f32(vsum2); - float sum3 = _mm256_cvtss_f32(vsum3); - for (; i < n; 
i++) { - sum0 += a[i] * b[i]; - sum1 += a[i] * b[i + n]; - sum2 += a[i] * b[i + 2 * n]; - sum3 += a[i] * b[i + 3 * n]; - } - c[0] = sum0; - c[1] = sum1; - c[2] = sum2; - c[3] = sum3; - c += 4; - b += 4 * n; - } - for (; j < block_size; j++) { - auto vsum = _mm256_set1_ps(0.0f); - size_t i = 0; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - vsum = _mm256_fmadd_ps(va, mm256_uni_loadu_ps(b + i), vsum); - } - hsum(vsum); - float sum = _mm256_cvtss_f32(vsum); - for (; i < n; i++) { - sum += a[i] * b[i]; - } - b += n; - *c++ = sum; - } - return; -#endif - for (size_t j = 0; j < block_size; j++) { - float sum = 0; - for (size_t i = 0; i < n; i++) { - sum += a[i] * b[i]; - } - b += n; - *c++ = sum; - } -} - -template -static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) -#if defined(HAVE_AVX512F) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm512_setzero_ps(); - auto vsum1 = _mm512_setzero_ps(); - auto vsum2 = _mm512_setzero_ps(); - auto vsum3 = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto b1 = reinterpret_cast(b + n + 8); - auto b2 = reinterpret_cast(b + (n + 8) * 2); - auto b3 = reinterpret_cast(b + (n + 8) * 3); - auto v_zp0 = _mm512_set1_ps(b0[1]); - auto v_zp1 = _mm512_set1_ps(b1[1]); - auto v_zp2 = _mm512_set1_ps(b2[1]); - auto v_zp3 = _mm512_set1_ps(b3[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), v_zp0); - auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + n + 8)))), v_zp1); - auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), v_zp2); - auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), v_zp3); - - vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); - } - float sum0 = _mm512_reduce_add_ps(vsum0); - float sum1 = _mm512_reduce_add_ps(vsum1); - float sum2 = _mm512_reduce_add_ps(vsum2); - float sum3 = _mm512_reduce_add_ps(vsum3); - for (; i < n; i++) { - sum0 += a[i] * (b[i] - b0[1]); - sum1 += a[i] * (b[i + n + 8] - b1[1]); - sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); - sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); - } - c[0] = sum0 * b0[0]; - c[1] = sum1 * b1[0]; - c[2] = sum2 * b2[0]; - c[3] = sum3 * b3[0]; - c += 4; - b += 4 * (n + 8) - 8; - } - for (; j < block_size; j++) { - auto vsum = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto v_zp = _mm512_set1_ps(b0[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), v_zp); - vsum = _mm512_fmadd_ps(va, vb, vsum); - } - float sum = _mm512_reduce_add_ps(vsum); - for (; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); - } - b 
+= n; - *c++ = sum * b0[0]; - } - return; -#elif defined(HAVE_AVX2) - size_t j = 0; - for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm256_setzero_ps(); - auto vsum1 = _mm256_setzero_ps(); - auto vsum2 = _mm256_setzero_ps(); - auto vsum3 = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto b1 = reinterpret_cast(b + n + 8); - auto b2 = reinterpret_cast(b + (n + 8) * 2); - auto b3 = reinterpret_cast(b + (n + 8) * 3); - auto v_zp0 = _mm256_set1_ps(b0[1]); - auto v_zp1 = _mm256_set1_ps(b1[1]); - auto v_zp2 = _mm256_set1_ps(b2[1]); - auto v_zp3 = _mm256_set1_ps(b3[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), v_zp0); - auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + n + 8)))), v_zp1); - auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), v_zp2); - auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), v_zp3); - - vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); - } - hsum(vsum0); - hsum(vsum1); - hsum(vsum2); - hsum(vsum3); - float sum0 = _mm256_cvtss_f32(vsum0); - float sum1 = _mm256_cvtss_f32(vsum1); - float sum2 = _mm256_cvtss_f32(vsum2); - float sum3 = _mm256_cvtss_f32(vsum3); - for (; i < n; i++) { - sum0 += a[i] * (b[i] - b0[1]); - sum1 += a[i] * (b[i + n + 8] - b1[1]); - sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); - sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); - } - c[0] = sum0 * b0[0]; - c[1] = sum1 * b1[0]; - c[2] = sum2 * b2[0]; - c[3] = sum3 * b3[0]; - c += 4; - b += 4 * (n + 8) - 8; - } - for (; j < block_size; j++) { - auto vsum = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto v_zp = _mm256_set1_ps(b0[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), v_zp); - vsum = _mm256_fmadd_ps(va, vb, vsum); - } - hsum(vsum); - float sum = _mm256_cvtss_f32(vsum); - for (; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); - } - b += n; - *c++ = sum * b0[0]; - } - return; -#endif - for (size_t j = 0; j < block_size; j++) { - float sum = 0; - auto b0 = reinterpret_cast(b); - b += 8; - for (size_t i = 0; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); - } - b += n; - *c++ = sum * b0[0]; - } -} - template static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) { size_t i = 0; @@ -1103,8 +613,6 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& alibi_mask, const ov::intel_cpu::PlainTensor& attention_mask, const ov::intel_cpu::PlainTensor& beams, - size_t max_context_len, - const ov::intel_cpu::PlainTensor& context_lens, ov::intel_cpu::PlainTensor& output_emb, ov::intel_cpu::PlainTensor& buf_attn_w, ov::intel_cpu::PlainTensor& buf_attn_score, @@ -1122,27 +630,20 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, auto S = query.size(3); auto h_group_num = present_value.size(1); size_t h_each_group_len = 1; - bool 
is_pagedattn = context_lens; - size_t block_size = present_value.size(2); if (h_group_num != H) { h_each_group_len = H / h_group_num; } if (d_scale == 0.0f) d_scale = 1.0f / sqrt(S); auto nthr = parallel_get_max_threads(); - size_t kv_len; - if (is_pagedattn) { - kv_len = max_context_len; - } else { - kv_len = present_key.size(2); - } + auto kv_len = present_key.size(2); #if defined(HAVE_AVX2) && !defined(HAVE_AVX512F) // avx2 will pre-compute the zero point and try to save the sub instruction in the dot_product, // but it seems not necessary for avx512. Possible reason may be that for avx2 the cost of dot_product // is larger than the memory access time, but for avx512 is not and the cost of pre-compute is a pure increase. bool pastkv_is_int8 = past_k_scale_zp; - if (pastkv_is_int8 && !is_pagedattn) { + if (pastkv_is_int8) { // be sure no false sharing head_sum.resize({B, H, q_len, 16}); parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { @@ -1151,271 +652,154 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, } #endif - // TODO: refactor to seperate files - if (is_pagedattn) { - // if present_key is true, it means q*k is already computed in the caller - if (present_key) { - if (B >= static_cast(nthr)) { - parallel_for2d_dynamic(B, beams.m_dims[1], [&](size_t b, size_t pk_in_blocks) { - auto context_len = static_cast(context_lens.ptr()[b]); - // kv_len must be valid - auto pk = pk_in_blocks * block_size; - if (pk < context_len) { - auto block_number = beams.ptr(b)[pk_in_blocks]; - for (size_t h_group = 0; h_group < h_group_num; h_group++) { - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, h_group), - buf_attn_w.ptr(b, h, pq) + pk, S, std::min(block_size, context_len - pk)); - } - } - } + parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { + size_t start{0}, end{0}; + splitter(B * h_group_num * kv_len, nthr, ithr, start, end); + + size_t b, h_group, pk; + if (start < end) { + parallel_it_init(start, pk, kv_len, b, B, h_group, h_group_num); + if (q_len == 1 && h_each_group_len == 1) { + if (B == 1) { + // the memory will be continuous when b==1 + for (size_t iwork = start; iwork < end; ++iwork) { + auto p = past_k_scale_zp.ptr(pk, 0, h_group); + auto p_k = present_key.ptr(0, h_group, pk); + prefetch_bytes(S, _MM_HINT_T0, 4096, p_k); + buf_attn_w.ptr(0, h_group, 0)[pk] = + dot_product(query.ptr(0, h_group), p_k, + S, p, p + 1, head_sum.ptr(0, h_group)); + parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); } - }); - } else { - parallel_for3d_dynamic(B, beams.m_dims[1], h_group_num, [&](size_t b, size_t pk_in_blocks, size_t h_group) { - auto context_len = static_cast(context_lens.ptr()[b]); - // kv_len must be valid - auto pk = pk_in_blocks * block_size; - if (pk < context_len) { - auto block_number = beams.ptr(b)[pk_in_blocks]; - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, h_group), - buf_attn_w.ptr(b, h, pq) + pk, S, std::min(block_size, context_len - pk)); - } - } + } else { + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? 
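+                        // beams, when provided, redirects this past position to the batch slot that actually owns the kv row (e.g. after beam-search reordering); otherwise batch b is used directly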
beams.ptr(b)[pk] : b; + auto p = past_k_scale_zp.ptr(pk, b_kv, h_group); + auto p_k = present_key.ptr(b_kv, h_group, pk); + buf_attn_w.ptr(b, h_group, 0)[pk] = + dot_product(query.ptr(b, h_group), p_k, + S, p, p + 1, head_sum.ptr(b, h_group)); + parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); } - }); - } - } - - parallel_for3d_dynamic(B, H, q_len, [&](size_t b, size_t h, size_t pq) { - auto cur_kv_len = static_cast(context_lens.ptr()[b]); - auto ncausal = cur_kv_len; - // apply attention mask & sofmax - float* alibi_ptr = alibi_mask ? &alibi_mask.at({b, h, pq, 0}, true) : nullptr; - uint8_t* attn_mask_ptr = nullptr; - auto attn_mask_prec = attention_mask.get_precision(); - if (attention_mask) - attn_mask_ptr = reinterpret_cast(&attention_mask.at({b, h, pq, 0}, true)); - uint8_t* cmask_ptr = causal_mask ? &causal_mask.at({b, h, pq, 0}, true) : nullptr; - attn_softmax_kernel(buf_attn_w.ptr(b, h, pq), - buf_attn_w.ptr(b, h, pq), - d_scale, - alibi_ptr, - attn_mask_ptr, - cmask_ptr, - select_nfltmax_at_0, - ncausal, - cur_kv_len, - attn_mask_prec, - ov::element::f32); - }); - - // attn_w * V - // there are enough works for each thread - if (B >= static_cast(nthr)) { - buf_attn_score.resize({static_cast(nthr), q_len, h_each_group_len, S}); - parallel_for2d_dynamic(B, h_group_num, [&](size_t b, size_t h_group) { - auto ithr = parallel_get_thread_num(); - auto context_len = static_cast(context_lens.ptr()[b]); - memset(buf_attn_score.ptr(ithr), 0, q_len * h_each_group_len * S * sizeof(float)); - for (size_t pv = 0; pv < context_len; pv += block_size) { - size_t pv_in_blocks = pv / block_size; - auto block_number = beams.ptr(b)[pv_in_blocks]; - auto* v = present_value.ptr(block_number, h_group); + } + } else { + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? beams.ptr(b)[pk] : b; for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; h++, group_idx++) { - attn_acc_value_block(buf_attn_score.ptr(ithr, pq, group_idx), - buf_attn_w.ptr(b, h, pq) + pv, - v, - S, - std::min(block_size, context_len - pv)); + auto p = past_k_scale_zp.ptr(pk, b_kv, h_group); + for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { + buf_attn_w.ptr(b, h, pq)[pk] = + dot_product(query.ptr(b, h, pq), present_key.ptr(b_kv, h_group, pk), + S, p, p + 1, head_sum.ptr(b, h, pq)); } } + parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); } - // convert to dst - for (size_t pq = 0; pq < q_len; pq++) - for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; h++, group_idx++) - cvt_copy(output_emb.ptr(b, pq, h * S), buf_attn_score.ptr(ithr, pq, group_idx), S); - }); - return; + } } - buf_attn_score.resize({static_cast(nthr), B, q_len, H, S}); - // buf_attn_w {B, H, q_len, kv_len} - parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { - memset(buf_attn_score.ptr(ithr, 0, 0, 0, 0), 0, buf_attn_score.stride(0) * sizeof(float)); - }); + }); + + parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { + auto cur_kv_len = kv_len; + auto ncausal = auto_causal ? (cur_kv_len - q_len + pq + 1) : cur_kv_len; + // apply attention mask & sofmax + float* alibi_ptr = alibi_mask ? 
&alibi_mask.at({b, h, pq, 0}, true) : nullptr; + uint8_t* attn_mask_ptr = nullptr; + auto attn_mask_prec = attention_mask.get_precision(); + if (attention_mask) + attn_mask_ptr = reinterpret_cast(&attention_mask.at({b, h, pq, 0}, true)); + uint8_t* cmask_ptr = causal_mask ? &causal_mask.at({b, h, pq, 0}, true) : nullptr; + attn_softmax_kernel(buf_attn_w.ptr(b, h, pq), + buf_attn_w.ptr(b, h, pq), + d_scale, + alibi_ptr, + attn_mask_ptr, + cmask_ptr, + select_nfltmax_at_0, + ncausal, + cur_kv_len, + attn_mask_prec, + ov::element::f32); + }); - auto kv_len_in_blocks = beams.m_dims[1]; - parallel_for3d_dynamic(B, kv_len_in_blocks, h_group_num, [&](size_t b, size_t pv_in_blocks, size_t h_group) { + // attn_w * V + // Fast Path if there are enough works for each thread + if (B >= static_cast(nthr)) { + buf_attn_score.resize({static_cast(nthr), q_len, h_each_group_len, S}); + parallel_for2d(B, h_group_num, [&](size_t b, size_t h_group) { auto ithr = parallel_get_thread_num(); - auto context_len = static_cast(context_lens.ptr()[b]); - auto pv = pv_in_blocks * block_size; - // kv_len must be valid - if (pv < context_len) { - auto block_number = beams.ptr(b)[pv_in_blocks]; - auto* v = present_value.ptr(block_number, h_group); + memset(buf_attn_score.ptr(ithr), 0, q_len * h_each_group_len * S * sizeof(float)); + for (size_t pv = 0; pv < kv_len; pv++) { + auto b_kv = beams ? beams.ptr(b)[pv] : b; + auto* v = present_value.ptr(b_kv, h_group, pv); + auto p = past_v_scale_zp.ptr(pv, b_kv, h_group); for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - attn_acc_value_block(buf_attn_score.ptr(ithr, b, pq, h), - buf_attn_w.ptr(b, h, pq) + pv, - v, - S, - std::min(block_size, context_len - pv)); + for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; h++, group_idx++) { + attn_acc_value(buf_attn_score.ptr(ithr, pq, group_idx), + buf_attn_w.ptr(b, h, pq)[pv], + v, + S, + p + 0, + p + 1); } } } - }); - } else { - parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { - size_t start{0}, end{0}; - splitter(B * h_group_num * kv_len, nthr, ithr, start, end); - - size_t b, h_group, pk; - if (start < end) { - parallel_it_init(start, pk, kv_len, b, B, h_group, h_group_num); - if (q_len == 1 && h_each_group_len == 1) { - if (B == 1) { - // the memory will be continuous when b==1 - for (size_t iwork = start; iwork < end; ++iwork) { - auto p = past_k_scale_zp.ptr(pk, 0, h_group); - auto p_k = present_key.ptr(0, h_group, pk); - prefetch_bytes(S, _MM_HINT_T0, 4096, p_k); - buf_attn_w.ptr(0, h_group, 0)[pk] = - dot_product(query.ptr(0, h_group), p_k, - S, p, p + 1, head_sum.ptr(0, h_group)); - parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); - } - } else { - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.ptr(b)[pk] : b; - auto p = past_k_scale_zp.ptr(pk, b_kv, h_group); - auto p_k = present_key.ptr(b_kv, h_group, pk); - buf_attn_w.ptr(b, h_group, 0)[pk] = - dot_product(query.ptr(b, h_group), p_k, - S, p, p + 1, head_sum.ptr(b, h_group)); - parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); - } - } - } else { - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? 
beams.ptr(b)[pk] : b; - for (size_t pq = 0; pq < q_len; pq++) { - auto p = past_k_scale_zp.ptr(pk, b_kv, h_group); - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - buf_attn_w.ptr(b, h, pq)[pk] = - dot_product(query.ptr(b, h, pq), present_key.ptr(b_kv, h_group, pk), - S, p, p + 1, head_sum.ptr(b, h, pq)); - } - } - parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); - } + // convert to dst + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; + h++, group_idx++) { + auto* dst = has_out_transpose ? output_emb.ptr(b, pq, h * S) : output_emb.ptr(b, h, pq); + cvt_copy(dst, buf_attn_score.ptr(ithr, pq, group_idx), S); } } }); - parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { - auto cur_kv_len = kv_len; - auto ncausal = auto_causal ? (cur_kv_len - q_len + pq + 1) : cur_kv_len; - // apply attention mask & sofmax - float* alibi_ptr = alibi_mask ? &alibi_mask.at({b, h, pq, 0}, true) : nullptr; - uint8_t* attn_mask_ptr = nullptr; - auto attn_mask_prec = attention_mask.get_precision(); - if (attention_mask) - attn_mask_ptr = reinterpret_cast(&attention_mask.at({b, h, pq, 0}, true)); - uint8_t* cmask_ptr = causal_mask ? &causal_mask.at({b, h, pq, 0}, true) : nullptr; - attn_softmax_kernel(buf_attn_w.ptr(b, h, pq), - buf_attn_w.ptr(b, h, pq), - d_scale, - alibi_ptr, - attn_mask_ptr, - cmask_ptr, - select_nfltmax_at_0, - ncausal, - cur_kv_len, - attn_mask_prec, - ov::element::f32); - }); - // attn_w * V - // Fast Path if there are enough works for each thread - if (B >= static_cast(nthr)) { - buf_attn_score.resize({static_cast(nthr), q_len, h_each_group_len, S}); - parallel_for2d(B, h_group_num, [&](size_t b, size_t h_group) { - auto ithr = parallel_get_thread_num(); - memset(buf_attn_score.ptr(ithr), 0, q_len * h_each_group_len * S * sizeof(float)); - for (size_t pv = 0; pv < kv_len; pv++) { + return; + } + + buf_attn_score.resize({static_cast(nthr), B, q_len, H, S}); + // buf_attn_w {B, H, q_len, kv_len} + parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { + size_t start{0}, end{0}; + splitter(B * h_group_num * kv_len, nthr, ithr, start, end); + + memset(buf_attn_score.ptr(ithr, 0, 0, 0, 0), 0, buf_attn_score.stride(0) * sizeof(float)); + + size_t b, h_group, pv; + if (start < end) { + parallel_it_init(start, pv, kv_len, b, B, h_group, h_group_num); + if (q_len == 1 && h_each_group_len == 1) { + for (size_t iwork = start; iwork < end; ++iwork) { auto b_kv = beams ? beams.ptr(b)[pv] : b; auto* v = present_value.ptr(b_kv, h_group, pv); auto p = past_v_scale_zp.ptr(pv, b_kv, h_group); - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; h++, group_idx++) { - attn_acc_value(buf_attn_score.ptr(ithr, pq, group_idx), - buf_attn_w.ptr(b, h, pq)[pv], - v, - S, - p + 0, - p + 1); - } - } - } - // convert to dst - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; - h++, group_idx++) { - auto* dst = has_out_transpose ? 
output_emb.ptr(b, pq, h * S) : output_emb.ptr(b, h, pq); - cvt_copy(dst, buf_attn_score.ptr(ithr, pq, group_idx), S); - } + attn_acc_value(buf_attn_score.ptr(ithr, b, 0, h_group), + buf_attn_w.ptr(b, h_group, 0, pv)[0], + v, + S, + p + 0, + p + 1); + parallel_it_step(pv, kv_len, b, B, h_group, h_group_num); } - }); - return; - } - - buf_attn_score.resize({static_cast(nthr), B, q_len, H, S}); - // buf_attn_w {B, H, q_len, kv_len} - parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) { - size_t start{0}, end{0}; - splitter(B * h_group_num * kv_len, nthr, ithr, start, end); - - memset(buf_attn_score.ptr(ithr, 0, 0, 0, 0), 0, buf_attn_score.stride(0) * sizeof(float)); - - size_t b, h_group, pv; - if (start < end) { - parallel_it_init(start, pv, kv_len, b, B, h_group, h_group_num); - if (q_len == 1 && h_each_group_len == 1) { - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.ptr(b)[pv] : b; - auto* v = present_value.ptr(b_kv, h_group, pv); - auto p = past_v_scale_zp.ptr(pv, b_kv, h_group); - attn_acc_value(buf_attn_score.ptr(ithr, b, 0, h_group), - buf_attn_w.ptr(b, h_group, 0, pv)[0], - v, - S, - p + 0, - p + 1); - parallel_it_step(pv, kv_len, b, B, h_group, h_group_num); - } - } else { - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.ptr(b)[pv] : b; - auto* v = present_value.ptr(b_kv, h_group, pv); - auto p = past_v_scale_zp.ptr(pv, b_kv, h_group); - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - attn_acc_value(buf_attn_score.ptr(ithr, b, pq, h), - buf_attn_w.ptr(b, h, pq)[pv], - v, - S, - p + 0, - p + 1); - } + } else { + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? 
beams.ptr(b)[pv] : b; + auto* v = present_value.ptr(b_kv, h_group, pv); + auto p = past_v_scale_zp.ptr(pv, b_kv, h_group); + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { + attn_acc_value(buf_attn_score.ptr(ithr, b, pq, h), + buf_attn_w.ptr(b, h, pq)[pv], + v, + S, + p + 0, + p + 1); } - parallel_it_step(pv, kv_len, b, B, h_group, h_group_num); } + parallel_it_step(pv, kv_len, b, B, h_group, h_group_num); } } - }); - } + } + }); parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { auto* temp = buf_attn_score.ptr(0, b, pq, h); @@ -1431,8 +815,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& alibi_mask, const ov::intel_cpu::PlainTensor& attention_mask, const ov::intel_cpu::PlainTensor& beams, - size_t max_context_len, - const ov::intel_cpu::PlainTensor& context_lens, ov::intel_cpu::PlainTensor& output_emb, ov::intel_cpu::PlainTensor& buf_attn_w, ov::intel_cpu::PlainTensor& buf_attn_score, @@ -1450,8 +832,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, - max_context_len, - context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -1468,8 +848,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, - max_context_len, - context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -1488,8 +866,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, - max_context_len, - context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -1506,8 +882,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, - max_context_len, - context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -1524,8 +898,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, - max_context_len, - context_lens, output_emb, buf_attn_w, buf_attn_score, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp index 646c4c01719529..e29e2bae0aa07a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp @@ -21,8 +21,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& alibi_mask, const ov::intel_cpu::PlainTensor& attention_mask, const ov::intel_cpu::PlainTensor& beams, - size_t max_context_len, - const ov::intel_cpu::PlainTensor& context_lens, ov::intel_cpu::PlainTensor& output_emb, ov::intel_cpu::PlainTensor& buf_attn_w, ov::intel_cpu::PlainTensor& buf_attn_score, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp index 0327742253c286..bdedfb02fa6096 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp @@ -516,7 +516,7 @@ inline void attn_softmax_kernel(float* a, ov::element::Type dst_precision) { using func_fp32_type = void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float&); using func_bf16_type = void (*)(float*, float, const float*, const ov::bfloat16*, const uint8_t*, bool, size_t, float&); - static func_fp32_type funcs_fp32[] = { + static constexpr 
func_fp32_type funcs_fp32[] = { scale_add2_reduce_max, scale_add2_reduce_max, scale_add2_reduce_max, @@ -526,7 +526,7 @@ inline void attn_softmax_kernel(float* a, scale_add2_reduce_max, scale_add2_reduce_max }; - static func_bf16_type funcs_bf16[] = { + static constexpr func_bf16_type funcs_bf16[] = { scale_add2_reduce_max, scale_add2_reduce_max, scale_add2_reduce_max, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp new file mode 100644 index 00000000000000..b39028792ee547 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp @@ -0,0 +1,254 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "common.hpp" +#include "openvino/core/type/element_type.hpp" + +#include +#include +#include +#include + +namespace ov { +namespace Extensions { +namespace Cpu { +namespace XARCH { + +#if defined(HAVE_AVX512F) +inline void transpose_m512i_16x16(__m512i& r0, __m512i& r1, __m512i& r2, __m512i& r3, + __m512i& r4, __m512i& r5, __m512i& r6, __m512i& r7, + __m512i& r8, __m512i& r9, __m512i& ra, __m512i& rb, + __m512i& rc, __m512i& rd, __m512i& re, __m512i& rf) { + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + + t0 = _mm512_unpacklo_epi32(r0, r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32(r0, r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2, r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2, r3); // 34 50 35 51 ... + t4 = _mm512_unpacklo_epi32(r4, r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4, r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6, r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6, r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8, r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8, r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra, rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra, rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc, rd); // 196 ... + td = _mm512_unpackhi_epi32(rc, rd); // 198 ... + te = _mm512_unpacklo_epi32(re, rf); // 228 ... + tf = _mm512_unpackhi_epi32(re, rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0, t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0, t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1, t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1, t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4, t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4, t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5, t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5, t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8, ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8, ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9, tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9, tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc, te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc, te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td, tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td, tf); // 195 211 231 243 ... + + t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... 
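+    // shuffle imm 0x88 keeps 128-bit lanes {0,2} of each source and 0xdd keeps lanes {1,3}, regrouping the rows at 128-bit granularity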
+ t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... + rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 +} + +template +inline void transpose_16x16_kernel(float* _dst, T* src, size_t dst_stride, size_t src_stride) { + auto* dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + r0 = _mm512_castps_si512(mm512_uni_loadu_ps(src)); + r1 = _mm512_castps_si512(mm512_uni_loadu_ps(src + src_stride)); + r2 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 2 * src_stride)); + r3 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 3 * src_stride)); + r4 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 4 * src_stride)); + r5 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 5 * src_stride)); + r6 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 6 * src_stride)); + r7 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 7 * src_stride)); + r8 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 8 * src_stride)); + r9 = _mm512_castps_si512(mm512_uni_loadu_ps(src + 9 * src_stride)); + ra = _mm512_castps_si512(mm512_uni_loadu_ps(src + 10 * src_stride)); + rb = _mm512_castps_si512(mm512_uni_loadu_ps(src + 11 * src_stride)); + rc = _mm512_castps_si512(mm512_uni_loadu_ps(src + 12 * src_stride)); + rd = _mm512_castps_si512(mm512_uni_loadu_ps(src + 13 * src_stride)); + re = _mm512_castps_si512(mm512_uni_loadu_ps(src + 14 * src_stride)); + rf = _mm512_castps_si512(mm512_uni_loadu_ps(src + 15 * src_stride)); + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_si512(dst, r0); + _mm512_storeu_si512(dst + dst_stride, r1); + _mm512_storeu_si512(dst + 2 * dst_stride, r2); + _mm512_storeu_si512(dst + 3 * dst_stride, r3); + _mm512_storeu_si512(dst + 4 * dst_stride, r4); + _mm512_storeu_si512(dst + 5 * dst_stride, r5); + _mm512_storeu_si512(dst + 6 * dst_stride, r6); + _mm512_storeu_si512(dst + 7 * dst_stride, r7); + _mm512_storeu_si512(dst + 8 * dst_stride, r8); + _mm512_storeu_si512(dst + 9 
* dst_stride, r9); + _mm512_storeu_si512(dst + 10 * dst_stride, ra); + _mm512_storeu_si512(dst + 11 * dst_stride, rb); + _mm512_storeu_si512(dst + 12 * dst_stride, rc); + _mm512_storeu_si512(dst + 13 * dst_stride, rd); + _mm512_storeu_si512(dst + 14 * dst_stride, re); + _mm512_storeu_si512(dst + 15 * dst_stride, rf); +} + +inline void transpose_16x16_kernel(uint32_t* dst, uint32_t* src, size_t dst_stride, size_t src_stride) { + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + r0 = _mm512_loadu_si512(src); + r1 = _mm512_loadu_si512(src + src_stride); + r2 = _mm512_loadu_si512(src + 2 * src_stride); + r3 = _mm512_loadu_si512(src + 3 * src_stride); + r4 = _mm512_loadu_si512(src + 4 * src_stride); + r5 = _mm512_loadu_si512(src + 5 * src_stride); + r6 = _mm512_loadu_si512(src + 6 * src_stride); + r7 = _mm512_loadu_si512(src + 7 * src_stride); + r8 = _mm512_loadu_si512(src + 8 * src_stride); + r9 = _mm512_loadu_si512(src + 9 * src_stride); + ra = _mm512_loadu_si512(src + 10 * src_stride); + rb = _mm512_loadu_si512(src + 11 * src_stride); + rc = _mm512_loadu_si512(src + 12 * src_stride); + rd = _mm512_loadu_si512(src + 13 * src_stride); + re = _mm512_loadu_si512(src + 14 * src_stride); + rf = _mm512_loadu_si512(src + 15 * src_stride); + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_si512(dst, r0); + _mm512_storeu_si512(dst + dst_stride, r1); + _mm512_storeu_si512(dst + 2 * dst_stride, r2); + _mm512_storeu_si512(dst + 3 * dst_stride, r3); + _mm512_storeu_si512(dst + 4 * dst_stride, r4); + _mm512_storeu_si512(dst + 5 * dst_stride, r5); + _mm512_storeu_si512(dst + 6 * dst_stride, r6); + _mm512_storeu_si512(dst + 7 * dst_stride, r7); + _mm512_storeu_si512(dst + 8 * dst_stride, r8); + _mm512_storeu_si512(dst + 9 * dst_stride, r9); + _mm512_storeu_si512(dst + 10 * dst_stride, ra); + _mm512_storeu_si512(dst + 11 * dst_stride, rb); + _mm512_storeu_si512(dst + 12 * dst_stride, rc); + _mm512_storeu_si512(dst + 13 * dst_stride, rd); + _mm512_storeu_si512(dst + 14 * dst_stride, re); + _mm512_storeu_si512(dst + 15 * dst_stride, rf); +} + +#elif defined(HAVE_AVX2) + +// https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2 +inline void transpose_8x8(__m256& r0, __m256& r1, __m256& r2, __m256& r3, __m256& r4, __m256& r5, __m256& r6, __m256& r7) { + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 tt0, tt1, tt2, tt3, tt4, tt5, tt6, tt7; + t0 = _mm256_unpacklo_ps(r0, r1); + t1 = _mm256_unpackhi_ps(r0, r1); + t2 = _mm256_unpacklo_ps(r2, r3); + t3 = _mm256_unpackhi_ps(r2, r3); + t4 = _mm256_unpacklo_ps(r4, r5); + t5 = _mm256_unpackhi_ps(r4, r5); + t6 = _mm256_unpacklo_ps(r6, r7); + t7 = _mm256_unpackhi_ps(r6, r7); + tt0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0)); + tt1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2)); + tt2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0)); + tt3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2)); + tt4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0)); + tt5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2)); + tt6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0)); + tt7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2)); + r0 = _mm256_permute2f128_ps(tt0, tt4, 0x20); + r1 = _mm256_permute2f128_ps(tt1, tt5, 0x20); + r2 = _mm256_permute2f128_ps(tt2, tt6, 0x20); + r3 = _mm256_permute2f128_ps(tt3, tt7, 0x20); + r4 = _mm256_permute2f128_ps(tt0, tt4, 0x31); + r5 = _mm256_permute2f128_ps(tt1, tt5, 0x31); + r6 = _mm256_permute2f128_ps(tt2, tt6, 
0x31); + r7 = _mm256_permute2f128_ps(tt3, tt7, 0x31); +} + +template +inline void transpose_16x16_kernel(float* dst, T* src, size_t dst_stride, size_t src_stride) { + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + for (int i = 0; i < 16; i += 8) { + for (int j = 0; j < 16; j += 8) { + r0 = mm256_uni_loadu_ps(src + src_stride * j); + r1 = mm256_uni_loadu_ps(src + src_stride * (1 + j)); + r2 = mm256_uni_loadu_ps(src + src_stride * (2 + j)); + r3 = mm256_uni_loadu_ps(src + src_stride * (3 + j)); + r4 = mm256_uni_loadu_ps(src + src_stride * (4 + j)); + r5 = mm256_uni_loadu_ps(src + src_stride * (5 + j)); + r6 = mm256_uni_loadu_ps(src + src_stride * (6 + j)); + r7 = mm256_uni_loadu_ps(src + src_stride * (7 + j)); + + transpose_8x8(r0, r1, r2, r3, r4, r5, r6, r7); + + _mm256_storeu_ps(dst + j, r0); + _mm256_storeu_ps(dst + j + dst_stride, r1); + _mm256_storeu_ps(dst + j + dst_stride * 2, r2); + _mm256_storeu_ps(dst + j + dst_stride * 3, r3); + _mm256_storeu_ps(dst + j + dst_stride * 4, r4); + _mm256_storeu_ps(dst + j + dst_stride * 5, r5); + _mm256_storeu_ps(dst + j + dst_stride * 6, r6); + _mm256_storeu_ps(dst + j + dst_stride * 7, r7); + } + src += 8; + dst += 8 * dst_stride; + } +} + +#else + +template +inline void transpose_16x16_kernel(TDST* dst, TSRC* src, size_t dst_stride, size_t src_stride) { + for (size_t i = 0; i < 16; i++) { + for (size_t j = 0; j < 16; j++) { + dst[i * dst_stride + j] = static_cast(src[i + j * src_stride]); + } + } +} + +#endif + +} // namespace XARCH +} // namespace Cpu +} // namespace Extensions +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp index a1aeaaecf9ae17..86f80b33a8c875 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp @@ -24,7 +24,8 @@ BrgemmKernel::BrgemmKernel(size_t M, size_t ldb, size_t ldc, bool b_transposed, - ov::element::Type inType) + ov::element::Type inType, + bool b_accumulate) : M(M), K(K), N(N), @@ -32,7 +33,8 @@ BrgemmKernel::BrgemmKernel(size_t M, ldb(ldb), ldc(ldc), b_transposed(b_transposed), - inType(inType) { + inType(inType), + b_accumulate(b_accumulate) { // blocking M M_blk = matmulOptimalM; M_tail = M % M_blk; @@ -45,7 +47,11 @@ BrgemmKernel::BrgemmKernel(size_t M, THROW_ERROR("brgemm bf16 kernel could only be used above avx512_bf16"); bool isAMXSupported = is_bf16 && mayiuse(avx512_core_amx); - size_t vlen = cpu_isa_traits::vlen; + size_t vlen; + if (mayiuse(avx512_core)) + vlen = cpu_isa_traits::vlen; + else + vlen = cpu_isa_traits::vlen; // blocking N N_blk = is_bf16 ? 32 : std::max(N, vlen / inType.size()); N_tail = N % N_blk; @@ -67,7 +73,7 @@ BrgemmKernel::BrgemmKernel(size_t M, auto M_ = m ? M_tail : M < M_blk ? 0 : M_blk; auto N_ = n ? N_tail : N - N_tail; auto K_ = k ? K_tail : K - K % K_blk; - auto beta = k && brgCtxs[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f; + auto beta = (b_accumulate || (k && brgCtxs[getBrgIdx(m, 0, n)].K != 0)) ? 1.0f : 0.0f; brgemmCtx.M = M_; brgemmCtx.N = N_; @@ -134,9 +140,14 @@ void BrgemmKernel::init_brgemm(brgemmCtx& ctx, const bool is_int8 = one_of(ctx.dt_in0, data_type::u8, data_type::s8) && one_of(ctx.dt_in1, data_type::u8, data_type::s8); - auto isa = use_amx ? isa_undef - : ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 - : (is_int8 ? avx512_core_vnni : avx512_core); + cpu_isa_t isa; + if (mayiuse(avx512_core)) { + isa = use_amx ? 
isa_undef + : ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 + : (is_int8 ? avx512_core_vnni : avx512_core); + } else { + isa = cpu_isa_t::avx2; + } auto status = brgemm_desc_init(&brgDesc, isa, brgemm_addr, @@ -158,6 +169,20 @@ void BrgemmKernel::init_brgemm(brgemmCtx& ctx, THROW_ERROR("cannot be executed due to invalid brgconv params"); } + if (use_amx && b_accumulate) { + brgemm_attr_t brgattr; + brgattr.max_bs = 1; + brgattr.wary_tail_read = false; + brgattr.hint_innermost_loop = brgemm_innermost_undef; + // if b_accumulate is true, it means we want c+=a*b. jit_brgemm_amx_uker_base_t::load_accumulators can support this using tileload(c) without postops + brgattr.use_uker = true; + brgattr.use_interleave_stores = true; + brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + if (brgemm_desc_set_attr(&brgDesc, brgattr) != dnnl_success) { + THROW_ERROR("cannot be executed due to brgemm_desc_set_attr failed"); + } + } + ctx.is_with_amx = use_amx; status = brgemm_init_tiles(brgDesc, ctx.palette); if (use_amx) { @@ -319,7 +344,7 @@ void BrgemmKernel::executeGemm(bool is_M_tail, void* a, void* b, void* c, void* for (size_t k = 0; k < 2; k++) { size_t mIdx = is_M_tail ? 1 : 0; auto& brgemmCtx = brgCtxs[getBrgIdx(mIdx, k, n)]; - if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { + if (brgemmCtx.K != 0 && brgemmCtx.N != 0 && brgemmCtx.M != 0) { auto local_a_ptr = k > 0 ? ptr_a_tail : ptr_A; auto B_stride = (k * count_K + n * count_N * brgVnniFactor) * inType.size(); auto weight_ptr = ptr_scartch_b + B_stride; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp index 7ba637722dc184..513b484ab0b963 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.hpp @@ -28,7 +28,8 @@ class BrgemmKernel { size_t ldb, size_t ldc, bool b_transposed = false, - ov::element::Type inType = ov::element::bf16); + ov::element::Type inType = ov::element::bf16, + bool b_accumulate = false); // execute all M void executeGemm(void* a, void* b, void* c, void* wsp, void* scratch_a, void* scratch_b); // execute by m_blk @@ -58,6 +59,7 @@ class BrgemmKernel { size_t packedBSize = 0; size_t packedASize = 0; ov::element::Type inType; + bool b_accumulate = false; static constexpr size_t MHA_BRGEMM_KERNELS_NUM = 8; static constexpr size_t matmulOptimalM = 32; struct brgemmCtx { diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp new file mode 100644 index 00000000000000..184ab10b852ba1 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -0,0 +1,216 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "paged_attn.h" + +#include "common/arbitrary_order_desc_creator.h" +#include "common/primitive_hashing_utils.hpp" +#include "cpu/x64/cpu_isa_traits.hpp" +#include "dnnl_extension_utils.h" +#include "memory_desc/cpu_memory_desc_utils.h" +#include "memory_desc/dnnl_blocked_memory_desc.h" +#include "onednn/dnnl.h" +#include "openvino/core/parallel.hpp" +#include "openvino/util/common_util.hpp" +#include "shape_inference/custom/paged_attn.hpp" +#include "shape_inference/shape_inference_internal_dyn.hpp" + +#include "utils/plain_tensor.hpp" +#include "kernels/scaled_attn/executor_pa.hpp" +#include "kernels/scaled_attn/attn_memcpy.hpp" +#include "kernels/scaled_attn/attn_quant.hpp" + +#include +#include +#include + +using 
namespace ov::Extensions::Cpu; +using namespace ov::Extensions::Cpu::XARCH; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + +namespace ov { +namespace intel_cpu { +namespace node { + +struct PagedAttentionKey { + ov::element::Type rtPrecision; + + size_t hash() const; + bool operator==(const PagedAttentionKey& rhs) const; +}; + +size_t PagedAttentionKey::hash() const { + size_t seed = 0; + seed = hash_combine(seed, rtPrecision.hash()); + + return seed; +} + +bool PagedAttentionKey::operator==(const PagedAttentionKey& rhs) const { + auto retVal = rtPrecision == rhs.rtPrecision; + + return retVal; +} + +PagedAttention::PagedAttention(const std::shared_ptr& op, const GraphContext::CPtr context) + : Node(op, context, PAShapeInferFactory(op)) { + std::string errorMessage; + if (!isSupportedOperation(op, errorMessage)) { + OPENVINO_THROW("CPU: " + errorMessage); + } +} + +void PagedAttention::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + auto rtPrecision = getRuntimePrecision(); + + NodeConfig config; + auto& creatorsMap = BlockedDescCreator::getCommonCreators(); + auto orgInputNumber = getOriginalInputsNumber(); + config.inConfs.resize(orgInputNumber); + config.outConfs.resize(getOriginalOutputsNumber()); + config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(0))); + config.inConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(1))); + config.inConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(2))); + + OPENVINO_ASSERT(orgInputNumber == 13 || orgInputNumber == 14, "The input number of PagedAttention should be 13 or 14."); + // kvcache, float, [] + auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + config.inConfs[PagedAttentionExecutor::ID_KCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); + config.inConfs[PagedAttentionExecutor::ID_VCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); + // is_prompt, bool, [] + config.inConfs[PagedAttentionExecutor::ID_IS_PROMPT].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::u8, getInputShapeAtPort(PagedAttentionExecutor::ID_IS_PROMPT))); + // slot_mapping, int, [batch_size, max_context_len] + config.inConfs[PagedAttentionExecutor::ID_SLOT_MAPPING].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_SLOT_MAPPING))); + // max_context_len, int, [] + config.inConfs[PagedAttentionExecutor::ID_MAX_CONTEXT_LEN].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_MAX_CONTEXT_LEN))); + // context_lens, int, [batch_size] + config.inConfs[PagedAttentionExecutor::ID_CONTEXT_LENS].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_CONTEXT_LENS))); + // block_tables, int, [batch_size, max_block_per_request] + config.inConfs[PagedAttentionExecutor::ID_BLOCK_TABLES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, 
getInputShapeAtPort(PagedAttentionExecutor::ID_BLOCK_TABLES))); + // scale, float, [] + config.inConfs[PagedAttentionExecutor::ID_SCALE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(PagedAttentionExecutor::ID_SCALE))); + // alibi_slopes, float, [?] or nullptr + config.inConfs[PagedAttentionExecutor::ID_ALIBI_SLOPES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(PagedAttentionExecutor::ID_ALIBI_SLOPES))); + // sliding_window, int, [] + config.inConfs[PagedAttentionExecutor::ID_SLIDING_WINDOW].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_SLIDING_WINDOW))); + if (orgInputNumber == 14) { + // subsequence_lens, int, [batch_size] + config.inConfs[PagedAttentionExecutor::ID_SUBSEQUENCE_LENS].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(PagedAttentionExecutor::ID_SUBSEQUENCE_LENS))); + } + + config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getOutputShapeAtPort(0))); + + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); +} + +void PagedAttention::createPrimitive() { + auto rtPrecision = getRuntimePrecision(); + + // in one model, kvCachePrecision could not be changed so no need to care whether it may be changed. + PagedAttentionKey key = {rtPrecision}; + + auto builder = [&](const PagedAttentionKey& key) -> std::shared_ptr { +#ifdef OPENVINO_ARCH_X86_64 + auto kvCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + return make_pa_executor(rtPrecision, kvCachePrecision); +#else + return nullptr; +#endif + }; + + auto cache = context->getParamsCache(); + auto result = cache->getOrCreate(key, builder); + if (!result.first) { + OPENVINO_THROW("PagedAttention AttentionExecutor creation fails with precision " + rtPrecision.to_string()); + } + m_executor = result.first; +} + +void PagedAttention::execute(dnnl::stream strm) { + auto orginInputNumber = getOriginalInputsNumber(); + std::vector inputs(orginInputNumber); + auto output = getDstMemoryAtPort(0); + for (size_t i = 0; i < orginInputNumber; i++) { + inputs[i] = getSrcMemoryAtPort(i); + } + + gatherConcatPastkvForPagedAttn(inputs); + + m_executor->execute(inputs, output); +} + +bool PagedAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { + try { + int orgInput = static_cast(op->get_input_size()); + if (op->get_type_name() == std::string("PagedAttentionExtension") && orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) { + return true; + } + } catch (...) { + return false; + } + return true; +} + +void PagedAttention::gatherConcatPastkvForPagedAttn(const std::vector& inputs) { + PlainTensor k, v, k_cache, v_cache, slot_mapping; + + k.reset(inputs[PagedAttentionExecutor::ID_K]); // [B, L1, H * S] + v.reset(inputs[PagedAttentionExecutor::ID_V]); + k_cache.reset(inputs[PagedAttentionExecutor::ID_KCACHE]); // [NUM_BLOCKS, H, 32, S] + v_cache.reset(inputs[PagedAttentionExecutor::ID_VCACHE]); // [NUM_BLOCKS, H, 32, S] + slot_mapping.reset(inputs[PagedAttentionExecutor::ID_SLOT_MAPPING]); // [B, max_context_len] + + auto B = k.size(0); + auto L1 = k.size(1); + auto H = k_cache.size(1); + auto S = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? 
8 : 0); + + k.assert_dims({B, L1, H * S}); + v.assert_dims({B, L1, H * S}); + slot_mapping.assert_dims({B, 0}, true); + k = k.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); + v = v.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); + if (k_cache.m_dt == ov::element::Type_t::u8) { + k_cache.assert_dims({0, H, 0, S + 8}, true); + v_cache.assert_dims({k_cache.m_dims[0], H, k_cache.m_dims[2], S + 8}); + paged_attn_quantkv(k, v, k_cache, v_cache, slot_mapping); + } else { + k_cache.assert_dims({0, H, 0, S}, true); + v_cache.assert_dims({k_cache.m_dims[0], H, k_cache.m_dims[2], S}); + paged_attn_memcpy(k, v, k_cache, v_cache, slot_mapping); + } +} + +ov::element::Type PagedAttention::getRuntimePrecision() const { + auto rtPrecision = getOriginalInputPrecisionAtPort(0); + // bf16 should be enabled only when platform supports + if (rtPrecision == ov::element::bf16 && ov::with_cpu_x86_bfloat16()) { + rtPrecision = ov::element::bf16; + } else { + rtPrecision = ov::element::f32; + } + return rtPrecision; +} + +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.h b/src/plugins/intel_cpu/src/nodes/paged_attn.h new file mode 100644 index 00000000000000..91e306626b5a80 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.h @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "memory_state.h" +#include "node.h" +#include "transformations/cpu_opset/common/op/sdpa.hpp" +#include "utils/plain_tensor.hpp" +#include "kernels/scaled_attn/executor_pa.hpp" + +namespace ov { +namespace intel_cpu { +namespace node { + +class PagedAttention : public Node { +public: + PagedAttention(const std::shared_ptr& op, const GraphContext::CPtr context); + + void getSupportedDescriptors() override {} + bool created() const override { + return getType() == Type::PagedAttention; + } + // pastkv may have zero dimension + bool isExecutable() const override { + return !isInputTensorAtPortEmpty(0) && !isInputTensorAtPortEmpty(1) && !isInputTensorAtPortEmpty(2); + } + bool needPrepareParams() const override { + return false; + } + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } + void initSupportedPrimitiveDescriptors() override; + void execute(dnnl::stream strm) override; + void createPrimitive() override; + static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + +private: + void gatherConcatPastkvForPagedAttn(const std::vector& inputs); + ov::element::Type getRuntimePrecision() const override; + + std::shared_ptr m_executor; + template struct AttentionExecutor; + friend struct PagedAttentionKey; +}; + +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 469850a94009b8..bdbb25505ca647 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -63,141 +63,6 @@ bool ScaledDotProductAttentionKey::operator==(const ScaledDotProductAttentionKey return retVal; } -#ifdef OPENVINO_ARCH_X86_64 - -// w = query * Key -// -// query: [1, head_size] -// Key : [block_size, head_size] -// w : [1, block_size] -// -// head_size is known at compile time -struct TileConfig { - uint8_t palette_id; - uint8_t startRow; - uint8_t reserved[14]; - uint16_t cols[16]; - uint8_t rows[16]; - void reset(int palette, int _startRow, const std::vector>& 
_rows_columnsBytes) { - palette_id = palette; - startRow = _startRow; - unsigned long i; - for (i = 0; i < 14; i++) { - reserved[i] = 0; - } - for (i = 0; i < _rows_columnsBytes.size(); i++) { - rows[i] = _rows_columnsBytes[i].first; - cols[i] = _rows_columnsBytes[i].second; - } - for (; i < 16; i++) { - cols[i] = 0; - rows[i] = 0; - } - } -}; - -class TileConfiger : public jit_generator { -public: - DECLARE_CPU_JIT_AUX_FUNCTIONS(TileConfiger) - TileConfiger() : jit_generator(jit_name()) { - create_kernel(); - } - void generate() override { - Xbyak::Label release; - test(abi_param1, abi_param1); - jz(release); - ldtilecfg(ptr[abi_param1]); - ret(); - L(release); - tilerelease(); - ret(); - } -}; - -class JitMatMulVecAMX : public jit_generator { - void operator=(const JitMatMulVecAMX&); - -public: - DECLARE_CPU_JIT_AUX_FUNCTIONS(JitMatMulVecAMX) - int m_head_size; - int m_block_size; - TileConfiger m_tile_configer; - TileConfig m_tile_cfg; - JitMatMulVecAMX(int head_size, int block_size) : jit_generator(jit_name()), m_head_size(head_size), m_block_size(block_size) { - create_kernel(); - m_tile_cfg.reset(1, - 0, - { - {16, 4}, // C:0 M x 1 (4b) - {16, 64}, // A:1 M x 32/64 (64b) - {16, 4}, // B:2 32/64 x 1 (4b) - {16, 4}, // B:3 - {16, 4}, // B:4 - {16, 4}, // B:5 - {16, 4}, // B:6 - {16, 4}, // B:7 - }); - } - - void tile_config() { - m_tile_configer(&m_tile_cfg); - } - void tile_release() { - m_tile_configer(nullptr); - } - - // to save push/pop: do not use `abi_save_gpr_regs` - Xbyak::Reg64 reg_q_addr = abi_param1; - Xbyak::Reg64 reg_k_addr = abi_param2; - Xbyak::Reg64 reg_dst_addr = abi_param3; - Xbyak::Reg64 reg_stride_A = rax; - Xbyak::Reg64 reg_stride_BC = r9; - - Xbyak::Tmm tmmC = tmm0; - Xbyak::Tmm tmmA = tmm1; - Xbyak::Tmm tmmB0 = tmm2; - Xbyak::Tmm tmmB1 = tmm3; - Xbyak::Tmm tmmB2 = tmm4; - Xbyak::Tmm tmmB3 = tmm5; - Xbyak::Tmm tmmB4 = tmm6; - Xbyak::Tmm tmmB5 = tmm7; - - void generate() override { - mov(reg_stride_A, m_head_size * 2); - mov(reg_stride_BC, 4); - const int kStep = 32; - if ((m_head_size % 32) != 0) - throw std::runtime_error("head size is not multiple of 32"); - if ((m_block_size % 16) != 0) - throw std::runtime_error("block size is not multiple of 16"); - auto num_B_tiles = m_head_size / kStep; - if (num_B_tiles > 6) - throw std::runtime_error("number of B tiles is bigger than 6"); - - /* - B(query) head_size x 1 - A(key) matrix : block_size x head_size C(dst) block_size x 1 - */ - // load query into B tiles - for (int i = 0; i < num_B_tiles; i++) { - tileloadd(Xbyak::Tmm(tmmB0.getIdx() + i), ptr[reg_q_addr + reg_stride_BC + i * 64]); - } - - for (int m = 0; m < m_block_size; m += 16) { - tilezero(tmmC); - for (int i = 0; i < num_B_tiles; i++) { - tileloadd(tmmA, ptr[reg_k_addr + reg_stride_A + i * 64]); - tdpbf16ps(tmmC, tmmA, Xbyak::Tmm(tmmB0.getIdx() + i)); - } - tilestored(ptr[reg_dst_addr + reg_stride_BC + m * sizeof(float)], tmmC); - add(reg_k_addr, m_head_size * 2 * 16); - } - ret(); - } -}; - -#endif - // default implementation: reference template struct MHAKernel { @@ -260,8 +125,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f, - size_t sliding_window = 0) { + float d_scale = 0.0f) { auto B = query.size(0); auto H = query.size(1); auto q_len = query.size(2); @@ -313,18 +177,7 @@ struct MHAKernel { } // softmax - if (sliding_window) { - size_t start_idx = 0; - auto new_causal = ncausal; - if (ncausal > sliding_window) { - start_idx = ncausal - static_cast(sliding_window); - new_causal = 
sliding_window; - } - softmax(&attn_score[start_idx], new_causal); - memset(&attn_score[0], 0, sizeof(float) * start_idx); - } else { - softmax(&attn_score[0], ncausal); - } + softmax(&attn_score[0], ncausal); // linearly combine value word_vec.assign(head_size, 0.0f); @@ -497,8 +350,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f, - size_t sliding_window = 0) { + float d_scale = 0.0f) { const auto B = query.size(0); const auto H = query.size(1); const auto q_len = query.size(2); @@ -561,39 +413,17 @@ struct MHAKernel { for (size_t m = m_start; m < m_end; m++) { // apply attention mask & sofmax auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len; - if (sliding_window) { - size_t start_idx = 0; - auto new_causal = ncausal; - if (ncausal > sliding_window) { - start_idx = ncausal - static_cast(sliding_window); - new_causal = sliding_window; - } - attn_softmax(&score.at({b, h, m, start_idx}), - &weight.at({b, h, m, start_idx}), - d_scale, - alibi_ptr + m * alibi_stride, - attn_mask_ptr + m * attn_mask_stride, - cmask_ptr + m * cmask_stride, - select_nfltmax_at_0, - new_causal, - kv_len - start_idx, - precision_of::value, - precision_of::value); - - memset(&weight.at({b, h, m, 0}), 0, sizeof(T) * start_idx); - } else { - attn_softmax(&score.at({b, h, m, 0}), - &weight.at({b, h, m, 0}), - d_scale, - alibi_ptr + m * alibi_stride, - attn_mask_ptr + m * attn_mask_stride, - cmask_ptr + m * cmask_stride, - select_nfltmax_at_0, - ncausal, - kv_len, - precision_of::value, - precision_of::value); - } + attn_softmax(&score.at({b, h, m, 0}), + &weight.at({b, h, m, 0}), + d_scale, + alibi_ptr + m * alibi_stride, + attn_mask_ptr + m * attn_mask_stride, + cmask_ptr + m * cmask_stride, + select_nfltmax_at_0, + ncausal, + kv_len, + precision_of::value, + precision_of::value); } T* w_ptr = &weight.at({b, h, m_start, 0}); float* fp32_out_ptr; @@ -657,8 +487,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f, - size_t sliding_window = 0) { + float d_scale = 0.0f) { auto head_size = query.size(3); if (d_scale == 0.0f) d_scale = 1.0f / sqrt(head_size); @@ -672,8 +501,7 @@ struct MHAKernel { output_emb, has_out_transpose, auto_causal, - d_scale, - sliding_window); + d_scale); } }; @@ -715,8 +543,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f, - size_t sliding_window = 0) { + float d_scale = 0.0f) { auto B = query.size(0); auto H = query.size(1); auto q_len = query.size(2); @@ -806,39 +633,17 @@ struct MHAKernel { for (size_t m = m_start; m < m_end; m++) { // apply attention mask & sofmax auto ncausal = auto_causal ? 
(kv_len - q_len + m + 1) : kv_len; - if (sliding_window) { - size_t start_idx = 0; - auto new_causal = ncausal; - if (ncausal > sliding_window) { - start_idx = ncausal - static_cast(sliding_window); - new_causal = sliding_window; - } - attn_softmax(qk + (m - m_start) * qk_m_stride + start_idx, - qk + (m - m_start) * qk_m_stride + start_idx, - d_scale, - alibi_ptr + m * alibi_stride, - attn_mask_ptr + m * attn_mask_stride, - cmask_ptr + m * cmask_stride, - select_nfltmax_at_0, - new_causal, - kv_len - start_idx, - ov::element::f32, - ov::element::f32); - - memset(qk + (m - m_start) * qk_m_stride, 0, sizeof(float) * start_idx); - } else { - attn_softmax(qk + (m - m_start) * qk_m_stride, - qk + (m - m_start) * qk_m_stride, - d_scale, - alibi_ptr + m * alibi_stride, - attn_mask_ptr + m * attn_mask_stride, - cmask_ptr + m * cmask_stride, - select_nfltmax_at_0, - ncausal, - kv_len, - ov::element::f32, - ov::element::f32); - } + attn_softmax(qk + (m - m_start) * qk_m_stride, + qk + (m - m_start) * qk_m_stride, + d_scale, + alibi_ptr + m * alibi_stride, + attn_mask_ptr + m * attn_mask_stride, + cmask_ptr + m * cmask_stride, + select_nfltmax_at_0, + ncausal, + kv_len, + ov::element::f32, + ov::element::f32); } mlas_sgemm("N", "N", @@ -864,9 +669,7 @@ struct MHASingleToken { PlainTensor m_attn_w; PlainTensor m_temp; PlainTensor m_head_sum; -#ifdef OPENVINO_ARCH_X86_64 - std::shared_ptr m_gemv; -#endif + MHASingleToken() {} // Q, K, V is ready, do attention @@ -883,8 +686,6 @@ struct MHASingleToken { const PlainTensor& attention_mask, PlainTensor& output_emb, const PlainTensor& beams, - size_t max_context_len, - const PlainTensor& context_lens, bool has_out_transpose, bool auto_causal, float d_scale, @@ -893,81 +694,13 @@ struct MHASingleToken { auto B = query.size(0); auto H = query.size(1); auto q_len = query.size(2); - bool is_pagedattn = context_lens; size_t kv_len; - if (is_pagedattn) { - kv_len = max_context_len; - } else { - kv_len = present_key.size(2); - } + kv_len = present_key.size(2); - bool fastpath_valid = false; -#ifdef OPENVINO_ARCH_X86_64 - if (is_pagedattn) { - auto S = query.size(3); - size_t block_size = present_value.size(2); - fastpath_valid = mayiuse(amx_bf16) && (S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6) && present_key.get_precision() == ov::element::bf16; - if (fastpath_valid) { - m_attn_w.resize({B, H, q_len, (kv_len + block_size - 1) / block_size * block_size}); - if (!m_gemv) - m_gemv = std::make_shared(static_cast(S), static_cast(block_size)); - auto h_group_num = present_value.size(1); - size_t h_each_group_len = 1; - if (h_group_num != H) { - h_each_group_len = H / h_group_num; - } - auto kv_len_in_blocks = beams.m_dims[1]; - auto nthr = static_cast(parallel_get_max_threads()); - size_t real_len = 0; - for (size_t b = 0; b < B; b++) - real_len += static_cast(context_lens.ptr()[b]) / block_size; - if (real_len > nthr) { - parallel_for2d_dynamic(B, kv_len_in_blocks, [&](size_t b, size_t pk_in_blocks) { - auto context_len = static_cast(context_lens.ptr()[b]); - // kv_len must be valid - auto pk = pk_in_blocks * block_size; - if (pk < context_len) { - m_gemv->tile_config(); - auto block_number = beams.ptr(b)[pk_in_blocks]; - for (size_t h_group = 0; h_group < h_group_num; h_group++) { - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - (*m_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, h_group), - m_attn_w.ptr(b, h, pq) + pk); - } - } - } - m_gemv->tile_release(); - } 
- }); - } else { - parallel_for3d_dynamic(B, kv_len_in_blocks, h_group_num, [&](size_t b, size_t pk_in_blocks, size_t h_group) { - auto context_len = static_cast(context_lens.ptr()[b]); - // kv_len must be valid - auto pk = pk_in_blocks * block_size; - if (pk < context_len) { - m_gemv->tile_config(); - auto block_number = beams.ptr(b)[pk_in_blocks]; - - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - (*m_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, h_group), - m_attn_w.ptr(b, h, pq) + pk); - } - } - m_gemv->tile_release(); - } - }); - } - } - } -#endif - if (!fastpath_valid) { - // aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing - m_attn_w.resize({B, H, q_len, (kv_len + 15) / 16 * 16}); - } - mha_single_token(query, fastpath_valid ? PlainTensor() : present_key, present_value, alibi_mask, attention_mask, beams, max_context_len, - context_lens, output_emb, m_attn_w, m_temp, has_out_transpose, auto_causal, d_scale, k_scale_zp, v_scale_zp, m_head_sum); + // aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing + m_attn_w.resize({B, H, q_len, (kv_len + 15) / 16 * 16}); + mha_single_token(query, present_key, present_value, alibi_mask, attention_mask, beams, + output_emb, m_attn_w, m_temp, has_out_transpose, auto_causal, d_scale, k_scale_zp, v_scale_zp, m_head_sum); } }; @@ -995,110 +728,65 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt bool fuse_causal_attn = config.config.fuse_causal_attn; bool is_causal = config.config.is_causal; bool fuse_concat = config.config.fuse_concat; - bool is_pagedattn = config.is_pageattn; auto input_num = inputs.size(); - bool is_prompt = false; PlainTensor present_key, present_value; PlainTensor q_input; // f32[B, H, L1, S] PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] PlainTensor v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] PlainTensor beam_table; // i32[B, max_kvLen] - PlainTensor context_lens; PlainTensor attn_mask; PlainTensor output_emb(output); float scale_input = 0.0f; size_t B, L1, L0, S; - size_t sliding_window = 0; - size_t max_context_len = 0; q_input.reset(inputs[0]); k_input.reset(inputs[1]); v_input.reset(inputs[2]); present_key.reset(presentk_input); present_value.reset(presentv_input); - if (is_pagedattn) { - is_prompt = *inputs[ID_IS_PROMPT]->getDataAs() == 1; - max_context_len = static_cast(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs()); - context_lens.reset(inputs[ID_CONTEXT_LENS]); - beam_table.reset(inputs[ID_BLOCK_TABLES]); - scale_input = *inputs[ID_SCALE]->getDataAs(); - // TODO: alibi and sliding window - // no attn mask, auto-generated casual mask - is_causal = true; - has_out_transpose = true; - - // q: [B, L1, H*S], kv: [B, L1, Hk*S] - // k_cache: [NUM_BLOCKS, Hk, 16, S] - // v_cache: [NUM_BLOCKS, Hk, 16, S] - // context_lens: [B] - // block_tables: [B, max_block_per_request] - B = k_input.size(0); - L1 = k_input.size(1); - auto Hk = present_key.size(1); - // The layout for per token per head for u8 kv cache: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The actual size needs to deduct scale and zeropoint. - S = present_value.size(3) - (present_value.m_dt == ov::element::Type_t::u8 ? 
sizeof(float) * 2 : 0); - auto H = q_input.size(2) / S; - // L0 in each batch may be different - L0 = 0; - - q_input.assert_dims({B, L1, H * S}); - if (!is_prompt) { - context_lens.assert_dims({B}); - beam_table.assert_dims({B, 0}, true); + if (beam_input) + beam_table.reset(beam_input); + if (input_num > 3) { + // attn_mask + if (inputs[3]->getDesc().getPrecision() == ov::element::u8) { + // bool->f32 + prepare_attn_mask(inputs[3]); + attn_mask = attn_buf; } else { - sliding_window = static_cast(*inputs[ID_SLIDING_WINDOW]->getDataAs()); + attn_mask.reset(inputs[3]); } - output_emb.assert_dims({B, L1, H * S}); - q_input = q_input.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); - k_input = k_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3}); - v_input = v_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3}); - } else { - if (beam_input) - beam_table.reset(beam_input); - if (input_num > 3) { - // attn_mask - if (inputs[3]->getDesc().getPrecision() == ov::element::u8) { - // bool->f32 - prepare_attn_mask(inputs[3]); - attn_mask = attn_buf; - } else { - attn_mask.reset(inputs[3]); - } - // if has scale, attn_mask must be present - if (input_num > 4) { - scale_input = *inputs[4]->getDataAs(); - } + // if has scale, attn_mask must be present + if (input_num > 4) { + scale_input = *inputs[4]->getDataAs(); } + } - // q: [B, H, L1, S] - const auto & permute_axes = config.config.permute_axes; - if (!permute_axes.empty()) { - q_input = q_input.permute(permute_axes); - k_input = k_input.permute(permute_axes); - v_input = v_input.permute(permute_axes); - present_key = present_key.permute(permute_axes); - present_value = present_value.permute(permute_axes); - } - B = q_input.size(0); - L1 = q_input.size(2); - S = q_input.size(3); - L0 = present_key.size(2) - L1; - auto Hk = k_input.size(1); - - if (fuse_concat) { - k_input.assert_dims({B, Hk, L1, S}); - v_input.assert_dims({B, Hk, L1, S}); - } else { - k_input.assert_dims({B, Hk, L0 + L1, S}); - v_input.assert_dims({B, Hk, L0 + L1, S}); - } - present_key.assert_dims({B, Hk, L0 + L1, S}); - present_value.assert_dims({B, Hk, L0 + L1, S}); - if (beam_table) - beam_table.assert_dims({B, L0 + L1}); + // q: [B, H, L1, S] + const auto & permute_axes = config.config.permute_axes; + if (!permute_axes.empty()) { + q_input = q_input.permute(permute_axes); + k_input = k_input.permute(permute_axes); + v_input = v_input.permute(permute_axes); + present_key = present_key.permute(permute_axes); + present_value = present_value.permute(permute_axes); + } + B = q_input.size(0); + L1 = q_input.size(2); + S = q_input.size(3); + L0 = present_key.size(2) - L1; + auto Hk = k_input.size(1); + + if (fuse_concat) { + k_input.assert_dims({B, Hk, L1, S}); + v_input.assert_dims({B, Hk, L1, S}); + } else { + k_input.assert_dims({B, Hk, L0 + L1, S}); + v_input.assert_dims({B, Hk, L0 + L1, S}); } + present_key.assert_dims({B, Hk, L0 + L1, S}); + present_value.assert_dims({B, Hk, L0 + L1, S}); + if (beam_table) + beam_table.assert_dims({B, L0 + L1}); bool auto_causal; bool use_attn_mask; @@ -1130,15 +818,11 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt } // second token, or first token with pastkv fusing - bool use_one_token; - if (is_pagedattn) - use_one_token = !is_prompt; - else - use_one_token = L1 == 1 || (fuse_concat && L0 > 0); + bool use_one_token = L1 == 1 || (fuse_concat && L0 > 0); if (!use_one_token) { // multi-token version kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? 
attn_mask : PlainTensor(), - output_emb, has_out_transpose, auto_causal, scale_input, sliding_window); + output_emb, has_out_transpose, auto_causal, scale_input); } else { // 1-token version // for second token, using a special AVX2/AVX512 float path: @@ -1146,7 +830,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt // 2, using float will save the repack cost which typically is required for bf16/int8 opt // 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(), - output_emb, beam_table, max_context_len, context_lens, has_out_transpose, auto_causal, scale_input, k_scale_zp, v_scale_zp); + output_emb, beam_table, has_out_transpose, auto_causal, scale_input, k_scale_zp, v_scale_zp); } } }; @@ -1158,18 +842,12 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptrget_type_name() == std::string("PagedAttentionExtension")) { - m_is_pageattn = true; - m_config.is_pageattn = true; + const auto node = std::dynamic_pointer_cast(op); + if (node) { + m_config.config.is_causal = node->get_causal(); } else { - m_is_pageattn = false; - const auto node = std::dynamic_pointer_cast(op); - if (node) { - m_config.config.is_causal = node->get_causal(); - } else { - const auto node = std::dynamic_pointer_cast(op); - m_config.config = node->get_config(); - } + const auto node = std::dynamic_pointer_cast(op); + m_config.config = node->get_config(); } } @@ -1189,83 +867,49 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { rtPrecision, getInputShapeAtPort(1))); config.inConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getInputShapeAtPort(2))); - if (m_is_pageattn) { - OPENVINO_ASSERT(getOriginalInputsNumber() == 13, "The input number of PagedAttention should be 13."); - // kvcache, float, [] - auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(ID_KCACHE); - config.inConfs[ID_KCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_kv_input_mem_precision, getInputShapeAtPort(ID_KCACHE))); - config.inConfs[ID_VCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_kv_input_mem_precision, getInputShapeAtPort(ID_VCACHE))); - // is_prompt, bool, [] - config.inConfs[ID_IS_PROMPT].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::u8, getInputShapeAtPort(ID_IS_PROMPT))); - // slot_mapping, int, [batch_size, max_context_len] - config.inConfs[ID_SLOT_MAPPING].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(ID_SLOT_MAPPING))); - // max_context_len, int, [] - config.inConfs[ID_MAX_CONTEXT_LEN].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(ID_MAX_CONTEXT_LEN))); - // context_lens, int, [batch_size] - config.inConfs[ID_CONTEXT_LENS].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(ID_CONTEXT_LENS))); - // block_tables, int, [batch_size, max_block_per_request] - config.inConfs[ID_BLOCK_TABLES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(ID_BLOCK_TABLES))); - // scale, float, [] - config.inConfs[ID_SCALE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::f32, getInputShapeAtPort(ID_SCALE))); - // alibi_slopes, float, [?] 
or nullptr - config.inConfs[ID_ALIBI_SLOPES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::f32, getInputShapeAtPort(ID_ALIBI_SLOPES))); - // sliding_window, int, [] - config.inConfs[ID_SLIDING_WINDOW].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(ID_SLIDING_WINDOW))); - } else { - auto nextPortIdx = 3; - if (orginSDPInputNumber > 3) { - // attn_mask - if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) { - config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::u8, getInputShapeAtPort(nextPortIdx))); - } else { - config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - rtPrecision, getInputShapeAtPort(nextPortIdx))); - } - nextPortIdx++; - } - if (orginSDPInputNumber > 4) { + auto nextPortIdx = 3; + if (orginSDPInputNumber > 3) { + // attn_mask + if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::u8, getInputShapeAtPort(nextPortIdx))); + } else { config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::f32, getInputShapeAtPort(nextPortIdx))); + rtPrecision, getInputShapeAtPort(nextPortIdx))); } + nextPortIdx++; + } + if (orginSDPInputNumber > 4) { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(nextPortIdx))); + } - if (m_config.config.fuse_concat) { - // beam_idx - config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(orginSDPInputNumber + 0))); - - // Since the InputMemory nodes are simple proxy for the state memory as well as the init subgraph memory, - // it doesn't make sense to set the real KV cache precision, since we don't need any precision conversions - // provided by the common graph logic. We set precisions equal to the precisions of the state nodes to avoid - // reorder insertion in between MemoryInputSDPA and SDPA nodes. 
- - auto past_k_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 1)->getParent()->getOriginalOutputPrecisionAtPort(0); - // pastk - config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_k_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 1))); - - auto past_v_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 2)->getParent()->getOriginalOutputPrecisionAtPort(0); - // pastv - config.inConfs[orginSDPInputNumber + 2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_v_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 2))); - - config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_k_input_mem_precision, getOutputShapeAtPort(1))); - config.outConfs[1].inPlace(-1); - config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_v_input_mem_precision, getOutputShapeAtPort(2))); - config.outConfs[2].inPlace(-1); - } + if (m_config.config.fuse_concat) { + // beam_idx + config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(orginSDPInputNumber + 0))); + + // Since the InputMemory nodes are simple proxy for the state memory as well as the init subgraph memory, + // it doesn't make sense to set the real KV cache precision, since we don't need any precision conversions + // provided by the common graph logic. We set precisions equal to the precisions of the state nodes to avoid + // reorder insertion in between MemoryInputSDPA and SDPA nodes. + + auto past_k_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 1)->getParent()->getOriginalOutputPrecisionAtPort(0); + // pastk + config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_k_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 1))); + + auto past_v_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 2)->getParent()->getOriginalOutputPrecisionAtPort(0); + // pastv + config.inConfs[orginSDPInputNumber + 2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_v_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 2))); + + config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_k_input_mem_precision, getOutputShapeAtPort(1))); + config.outConfs[1].inPlace(-1); + config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_v_input_mem_precision, getOutputShapeAtPort(2))); + config.outConfs[2].inPlace(-1); } config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( @@ -1324,35 +968,25 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { } PlainTensor k_scale_zp, v_scale_zp; - if (m_is_pageattn) { - gatherConcatPastkvForPagedAttn(inputs); - - presentk_input = inputs[ID_KCACHE]; - presentv_input = inputs[ID_VCACHE]; + if (m_config.config.fuse_concat) { + CPU_NODE_ASSERT(m_k_state && m_v_state, "has null input states"); + // initialization will be also completed in this func + gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber)); + + presentk_input = m_k_state->internal_state_mem(); + presentv_input = m_v_state->internal_state_mem(); + beam_input = m_k_state->hidden_state_mem(); + k_scale_zp = m_k_state->get_scale_zp(); + v_scale_zp = m_v_state->get_scale_zp(); } else { - if (m_config.config.fuse_concat) { - CPU_NODE_ASSERT(m_k_state && m_v_state, "has null 
input states"); - // initialization will be also completed in this func - gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber)); - - presentk_input = m_k_state->internal_state_mem(); - presentv_input = m_v_state->internal_state_mem(); - beam_input = m_k_state->hidden_state_mem(); - k_scale_zp = m_k_state->get_scale_zp(); - v_scale_zp = m_v_state->get_scale_zp(); - } else { - presentk_input = inputs[1]; - presentv_input = inputs[2]; - } + presentk_input = inputs[1]; + presentv_input = inputs[2]; } m_executor->execute(strm, m_config, inputs, output, presentk_input, presentv_input, beam_input, k_scale_zp, v_scale_zp); } bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - if (op->get_type_name() == std::string("PagedAttentionExtension")) { - return true; - } if (!std::dynamic_pointer_cast(op) && !std::dynamic_pointer_cast(op)) { errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionWithKVCache operation are supported"; @@ -1575,36 +1209,6 @@ void ScaledDotProductAttention::resetBeamTablePastkv(const MemoryPtr& mem_cur_k, } } -void ScaledDotProductAttention::gatherConcatPastkvForPagedAttn(const std::vector& inputs) { - PlainTensor k, v, k_cache, v_cache, slot_mapping; - - k.reset(inputs[ID_K]); // [B, L1, H * S] - v.reset(inputs[ID_V]); - k_cache.reset(inputs[ID_KCACHE]); // [NUM_BLOCKS, H, 16, S] - v_cache.reset(inputs[ID_VCACHE]); // [NUM_BLOCKS, H, 16, S] - slot_mapping.reset(inputs[ID_SLOT_MAPPING]); // [B, max_context_len] - - auto B = k.size(0); - auto L1 = k.size(1); - auto H = k_cache.size(1); - auto S = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? 8 : 0); - - k.assert_dims({B, L1, H * S}); - v.assert_dims({B, L1, H * S}); - slot_mapping.assert_dims({B, 0}, true); - k = k.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); - v = v.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); - if (k_cache.m_dt == ov::element::Type_t::u8) { - k_cache.assert_dims({0, H, 0, S + 8}, true); - v_cache.assert_dims({k_cache.m_dims[0], H, k_cache.m_dims[2], S + 8}); - paged_attn_quantkv(k, v, k_cache, v_cache, slot_mapping); - } else { - k_cache.assert_dims({0, H, 0, S}, true); - v_cache.assert_dims({k_cache.m_dims[0], H, k_cache.m_dims[2], S}); - paged_attn_memcpy(k, v, k_cache, v_cache, slot_mapping); - } -} - void ScaledDotProductAttention::gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx) { PlainTensor cur_k; cur_k.reset(mem_cur_k); diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index b94d5a030c4b49..c2de1f0d86ac4d 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -52,7 +52,6 @@ class ScaledDotProductAttention : public Node { private: void gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx); - void gatherConcatPastkvForPagedAttn(const std::vector& inputs); void updateBeamTable(const MemoryPtr& mem_beam_idx, size_t new_q_len); void updatePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v); ov::element::Type getRuntimePrecision() const override; @@ -60,7 +59,6 @@ class ScaledDotProductAttention : public Node { struct Config { ScaledDotProductAttentionWithKVCache::Config config; - bool is_pageattn = false; }; struct Executor { @@ -70,7 +68,6 @@ class ScaledDotProductAttention : public Node { virtual ~Executor() = default; }; - bool 
m_is_pageattn; Config m_config; std::shared_ptr m_executor; template struct AttentionExecutor; @@ -82,21 +79,6 @@ class ScaledDotProductAttention : public Node { // (0, 1, 2, 3) for BHLS // (2, 0, 1, 3) for LBHS std::vector m_kvstate_layout = {2, 0, 1, 3}; - - // PagedAttention input index - static const size_t ID_Q = 0; - static const size_t ID_K = 1; - static const size_t ID_V = 2; - static const size_t ID_KCACHE = 3; - static const size_t ID_VCACHE = 4; - static const size_t ID_IS_PROMPT = 5; - static const size_t ID_SLOT_MAPPING = 6; - static const size_t ID_MAX_CONTEXT_LEN = 7; - static const size_t ID_CONTEXT_LENS = 8; - static const size_t ID_BLOCK_TABLES = 9; - static const size_t ID_SCALE = 10; - static const size_t ID_ALIBI_SLOPES = 11; - static const size_t ID_SLIDING_WINDOW = 12; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp index 2146601b1c4cdd..f52b6d5f569692 100644 --- a/src/plugins/intel_cpu/src/nodes_factory.cpp +++ b/src/plugins/intel_cpu/src/nodes_factory.cpp @@ -61,6 +61,7 @@ #include "nodes/normalize.h" #include "nodes/one_hot.h" #include "nodes/pad.h" +#include "nodes/paged_attn.h" #include "nodes/pooling.h" #include "nodes/priorbox.h" #include "nodes/priorbox_clustered.h" @@ -205,6 +206,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") { INTEL_CPU_NODE(Interaction, Type::Interaction); INTEL_CPU_NODE(MHA, Type::MHA); INTEL_CPU_NODE(ScaledDotProductAttention, Type::ScaledDotProductAttention); + INTEL_CPU_NODE(PagedAttention, Type::PagedAttention); INTEL_CPU_NODE(Snippet, Type::Subgraph); #endif } diff --git a/src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.cpp b/src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.cpp new file mode 100644 index 00000000000000..52043464bb28c2 --- /dev/null +++ b/src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.cpp @@ -0,0 +1,38 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "paged_attn.hpp" + +#include "shape_inference/shape_inference_cpu.hpp" +#include "shape_inference/shape_inference_ngraph.hpp" +#include "transformations/cpu_opset/common/op/sdpa.hpp" +#include "utils.hpp" + +namespace ov { +namespace intel_cpu { +namespace node { + +class PAShapeInfer : public ShapeInferEmptyPads { +public: + PAShapeInfer() {} + + IShapeInfer::Result infer(const std::vector>& input_shapes, + const std::unordered_map& data_dependency) override { + const auto& query_dims = input_shapes.front().get(); + + return {{query_dims}, ShapeInferStatus::success}; + } + + port_mask_t get_port_mask() const override { + return EMPTY_PORT_MASK; + } +}; + +ShapeInferPtr PAShapeInferFactory::makeShapeInfer() const { + return std::make_shared(); +} + +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.hpp b/src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.hpp new file mode 100644 index 00000000000000..759f2d1dd8651b --- /dev/null +++ b/src/plugins/intel_cpu/src/shape_inference/custom/paged_attn.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "shape_inference/shape_inference_cpu.hpp" + +#pragma once +namespace ov { +namespace intel_cpu { +namespace node { + +class PAShapeInferFactory : public ShapeInferFactory { +public: + PAShapeInferFactory(std::shared_ptr op) : m_op(op) {} + ShapeInferPtr makeShapeInfer() 
const override; + +private: + std::shared_ptr<ov::Node> m_op; +}; +} // namespace node +} // namespace intel_cpu +} // namespace ov
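
Note on the u8 KV-cache path used in PagedAttention::gatherConcatPastkvForPagedAttn above: when the cache precision is u8, each token/head row of k_cache/v_cache stores a float scale, a float zero-point, and then S quantized bytes, which is why the head size is recovered as v_cache.size(3) - 8. The following is a minimal standalone sketch of reading one such row back to float; the helper name and the (q - zero_point) * scale dequantization formula are assumptions for illustration, not code taken from this patch.

// Sketch only: assumed per-token/head row layout |scale(f32)|zero_point(f32)|u8 x S|
// and an assumed (q - zero_point) * scale dequantization; names are illustrative.
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<float> dequant_u8_kv_row(const uint8_t* row, size_t S) {
    float scale = 0.f, zp = 0.f;
    std::memcpy(&scale, row, sizeof(float));               // bytes [0, 4): scale
    std::memcpy(&zp, row + sizeof(float), sizeof(float));  // bytes [4, 8): zero-point
    const uint8_t* q = row + 2 * sizeof(float);            // bytes [8, 8 + S): quantized features
    std::vector<float> out(S);
    for (size_t i = 0; i < S; ++i)
        out[i] = (static_cast<float>(q[i]) - zp) * scale;  // dequantize one feature
    return out;
}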
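
The slot_mapping input consumed by the same function maps each current-step token to a flat slot in the paged cache; with k_cache laid out as [NUM_BLOCKS, H, block_size, S], a slot decomposes into a block number and an in-block offset. Below is a conceptual sketch of that scatter for an f32 cache. The slot = block_number * block_size + block_offset convention, the "negative slot means padding" rule, and the slot_stride parameter are assumptions here; the real work is done by the optimized paged_attn_memcpy / paged_attn_quantkv kernels.

// Conceptual sketch (not the optimized kernel): scatter current-step K into an f32 paged cache.
// Assumed shapes: k [B, H, L1, S] (after the reshape/permute above),
// k_cache [NUM_BLOCKS, H, block_size, S], slot_mapping rows of length slot_stride (one entry per token).
#include <cstddef>
#include <cstdint>

void scatter_k_to_paged_cache(const float* k, float* k_cache, const int32_t* slot_mapping,
                              size_t B, size_t H, size_t L1, size_t S,
                              size_t block_size, size_t slot_stride) {
    for (size_t b = 0; b < B; ++b) {
        for (size_t l = 0; l < L1; ++l) {
            int32_t slot = slot_mapping[b * slot_stride + l];
            if (slot < 0)
                continue;  // assumed padded position
            size_t block = static_cast<size_t>(slot) / block_size;   // which cache block
            size_t offset = static_cast<size_t>(slot) % block_size;  // row inside the block
            for (size_t h = 0; h < H; ++h) {
                const float* src = k + (((b * H + h) * L1) + l) * S;
                float* dst = k_cache + (((block * H + h) * block_size) + offset) * S;
                for (size_t s = 0; s < S; ++s)
                    dst[s] = src[s];
            }
        }
    }
}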