[CPU] PagedAttention supports dynamic-split fuse (openvinotoolkit#24107)
### Details:
 - *Merge first-token and second-token inference into one parallel loop (see the sketch below)*
 - *~~Additional optimization: pre-transpose k-cache, pre-pack v-cache if needed~~*
 - *Additional optimization for first token: save the upper-triangle computation of q * k' and the lower-triangle computation of (q * k') * v*
 - *The C++ pipeline can enable it: ilya-lavrenov/openvino.genai#9*
 - *TODO (in another PR):*
   - alibi support
   - performance tuning
   - test cases

### Tickets:
 - *[138673](https://jira.devtools.intel.com/browse/CVS-138673)*
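
For orientation, here is a minimal sketch of the dynamic-split-fuse idea in plain C++. It is our illustration, not the PR's code (names like `WorkItem`, `attend_one`, and `run_fused` are invented; the PR's actual executor is the blocked, vectorized `executor_pa.cpp` registered below): prefill ("first token") and decode ("second token") queries are flattened into one work list, a single `ov::parallel_for` serves both, and causal masking is what saves the upper triangle of q * k' and the matching rows of (q * k') * v.

```cpp
// Hypothetical illustration only -- names and layout are ours, not the PR's.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>
#include "openvino/core/parallel.hpp"

struct WorkItem {
    const float* q;    // one query row            [head_size]
    const float* k;    // keys for this sequence   [ctx_len x head_size]
    const float* v;    // values for this sequence [ctx_len x head_size]
    float* out;        // output row               [head_size]
    size_t q_pos;      // absolute position of the query token
    size_t head_size;
};

// Causal attention for one query row: keys after q_pos are masked out,
// so the upper triangle of q*k' is never computed at all.
static void attend_one(const WorkItem& w) {
    std::vector<float> s(w.q_pos + 1);
    float m = -INFINITY;
    for (size_t j = 0; j <= w.q_pos; ++j) {
        float dot = 0.0f;
        for (size_t d = 0; d < w.head_size; ++d)
            dot += w.q[d] * w.k[j * w.head_size + d];
        s[j] = dot / std::sqrt(static_cast<float>(w.head_size));
        m = std::max(m, s[j]);
    }
    float sum = 0.0f;
    for (auto& x : s) {
        x = std::exp(x - m);
        sum += x;
    }
    for (size_t d = 0; d < w.head_size; ++d)
        w.out[d] = 0.0f;
    // Only rows 0..q_pos of v contribute: the lower-triangle part of
    // softmax(q*k') * v.
    for (size_t j = 0; j <= w.q_pos; ++j)
        for (size_t d = 0; d < w.head_size; ++d)
            w.out[d] += (s[j] / sum) * w.v[j * w.head_size + d];
}

// Dynamic-split-fuse: prefill chunks contribute many WorkItems, decode steps
// contribute one each, and a single parallel loop serves both kinds.
void run_fused(const std::vector<WorkItem>& items) {
    ov::parallel_for(items.size(), [&](size_t i) { attend_one(items[i]); });
}
```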
luo-cheng2021 authored May 15, 2024
1 parent 6950460 commit 04a7ecf
Showing 24 changed files with 2,815 additions and 1,348 deletions.
@@ -81,7 +81,8 @@ function(cross_compiled_file TARGET)
         ## that is arch ID
         set(_arch ${_it})
         if(_arch MATCHES ${_CURRENT_ARCH_FILTER})
-          list(APPEND _CUR_ARCH_SET ${_arch})
+          # make the non-/less-optimized version come first
+          list(INSERT _CUR_ARCH_SET 0 ${_arch})
           list(APPEND _FULL_ARCH_SET ${_arch})
         endif()
       else()
7 changes: 7 additions & 0 deletions src/plugins/intel_cpu/CMakeLists.txt
@@ -209,6 +209,13 @@ cross_compiled_file(${TARGET_NAME}
                     NAME mha_single_token
                     NAMESPACE ov::Extensions::Cpu::XARCH
 )
+cross_compiled_file(${TARGET_NAME}
+                    ARCH AVX512F AVX2 ANY
+                    src/nodes/kernels/scaled_attn/executor_pa.cpp
+                    API src/nodes/kernels/scaled_attn/executor_pa.hpp
+                    NAME make_pa_executor
+                    NAMESPACE ov::Extensions::Cpu::XARCH
+)
 cross_compiled_file(${TARGET_NAME}
                     ARCH AVX512F AVX2 ANY
                     src/nodes/kernels/scaled_attn/attn_memcpy.cpp
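
For readers unfamiliar with `cross_compiled_file`: as we understand it, the macro compiles the listed source once per `ARCH` entry (with the matching compiler flags) and generates a small dispatch stub for the `NAME` symbol that calls the most capable variant the host CPU supports, falling back to the plain `ANY` build. The APPEND-to-INSERT change in the first hunk makes the less-optimized variants come first in the per-file arch list, which the generation step presumably relies on. A conceptual sketch of such a stub (our assumption of the mechanism, not the generated code; the `kernel` namespaces are invented):

```cpp
// Conceptual sketch of per-arch dispatch; the real generated stub differs.
#include <cstdio>
#include "openvino/runtime/system_conf.hpp"

namespace AVX512F { void kernel() { std::puts("avx512 path"); } }  // built with AVX-512 flags
namespace AVX2    { void kernel() { std::puts("avx2 path"); } }    // built with AVX2 flags
namespace ANY     { void kernel() { std::puts("generic path"); } } // baseline build

void kernel_dispatch() {
    if (ov::with_cpu_x86_avx512f())
        return AVX512F::kernel();
    if (ov::with_cpu_x86_avx2())
        return AVX2::kernel();
    return ANY::kernel();
}
```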
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/src/cpu_types.cpp
@@ -237,7 +237,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         {"Ngram", Type::Ngram},
         {"ScaledDotProductAttention", Type::ScaledDotProductAttention},
         {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
-        {"PagedAttentionExtension", Type::ScaledDotProductAttention},
+        {"PagedAttentionExtension", Type::PagedAttention},
         {"RoPE", Type::RoPE},
         {"GatherCompressed", Type::Gather},
         {"CausalMaskPreprocess", Type::CausalMaskPreprocess},
@@ -358,6 +358,7 @@ std::string NameFromType(const Type type) {
         CASE(Unique);
         CASE(Ngram);
         CASE(ScaledDotProductAttention);
+        CASE(PagedAttention);
         CASE(RoPE);
         CASE(CausalMaskPreprocess);
         CASE(Unknown);
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
@@ -118,6 +118,7 @@ enum class Type {
     Unique,
     Ngram,
     ScaledDotProductAttention,
+    PagedAttention,
     RoPE,
     CausalMaskPreprocess,
 };
3 changes: 1 addition & 2 deletions src/plugins/intel_cpu/src/graph.cpp
@@ -1811,8 +1811,7 @@ void Graph::EnforceInferencePrecision() {
                 return true;
 
             // kvcache of PagedAttention should be written directly
-            if (node->getType() == Type::ScaledDotProductAttention && node->getOriginalInputsNumber() == 13 &&
-                (inPort == 3 || inPort == 4))
+            if (node->getType() == Type::PagedAttention && (inPort == 3 || inPort == 4))
                 return true;
             const auto &parent = node->getParentEdgeAt(inPort)->getParent();
             /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
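
For context on the port numbers: the old check identified PagedAttention as a 13-input ScaledDotProductAttention, so with a dedicated `Type::PagedAttention` the input-count test becomes unnecessary. Ports 3 and 4 are the key/value caches, which keep their original precision because the op writes into them directly. The enum below is our annotation of the presumed 13-input `PagedAttentionExtension` signature, not code from this PR:

```cpp
// Presumed input order of the 13-input PagedAttentionExtension op
// (annotation only -- verify against the op's definition).
enum PagedAttentionInput {
    Q = 0,
    K = 1,
    V = 2,
    KEY_CACHE = 3,    // written directly; BF16 enforcement skipped above
    VALUE_CACHE = 4,  // written directly; BF16 enforcement skipped above
    PAST_LENS = 5,
    SUBSEQUENCE_BEGINS = 6,
    BLOCK_INDICES = 7,
    BLOCK_INDICES_BEGINS = 8,
    SCALE = 9,
    SLIDING_WINDOW = 10,
    ALIBI_SLOPES = 11,
    MAX_CONTEXT_LEN = 12,
};
```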
33 changes: 2 additions & 31 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
@@ -17,6 +17,7 @@
 #include "openvino/core/parallel.hpp"
 #include "common.hpp"
 #include "attn_quant.hpp"
+#include "attn_quant_kernel.hpp"
 
 namespace ov {
 namespace Extensions {
@@ -259,37 +260,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float
 }
 
 void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) {
-    size_t i = 0;
-    // loadu_si128/epi64 does not support const qualifier
-    uint8_t* src_nc = const_cast<uint8_t*>(src);
-#if defined(HAVE_AVX512F)
-    auto v_zp = _mm512_set1_ps(zp);
-    auto v_scale = _mm512_set1_ps(scale);
-    for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) {
-        auto v0_128 = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i));
-        auto v0_512 = _mm512_cvtepu8_epi32(v0_128);
-        auto v0_value = _mm512_cvtepi32_ps(v0_512);
-        v0_value = _mm512_sub_ps(v0_value, v_zp);
-        auto v0_out = _mm512_mul_ps(v0_value, v_scale);
-        mm512_uni_storeu_ps(dst + i, v0_out);
-    }
-#elif defined(HAVE_AVX2)
-    auto v_zp = _mm256_set1_ps(zp);
-    auto v_scale = _mm256_set1_ps(scale);
-    for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) {
-        auto v0_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i));
-        auto v0_256 = _mm256_cvtepu8_epi32(v0_128);
-        auto v0_value = _mm256_cvtepi32_ps(v0_256);
-        v0_value = _mm256_sub_ps(v0_value, v_zp);
-        auto v0_out = _mm256_mul_ps(v0_value, v_scale);
-        mm256_uni_storeu_ps(dst + i, v0_out);
-    }
-#endif
-    for (; i < n; ++i) {
-        float tmp = src_nc[i];
-        tmp = (tmp - zp) * scale;
-        dst[i] = tmp;
-    }
+    attn_dequant_u8_kernel(src, dst, n, scale, zp);
 }
 
 } // namespace XARCH
56 changes: 56 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp
@@ -0,0 +1,56 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+#include "openvino/core/type/element_type.hpp"
+#include "utils/plain_tensor.hpp"
+
+namespace ov {
+namespace Extensions {
+namespace Cpu {
+namespace XARCH {
+
+template<typename TDST>
+void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) {
+    size_t i = 0;
+    // loadu_si128/epi64 does not support const qualifier
+    uint8_t* src_nc = const_cast<uint8_t*>(src);
+#if defined(HAVE_AVX512F)
+    auto v_zp = _mm512_set1_ps(zp);
+    auto v_scale = _mm512_set1_ps(scale);
+    for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) {
+        auto v0_128 = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i));
+        auto v0_512 = _mm512_cvtepu8_epi32(v0_128);
+        auto v0_value = _mm512_cvtepi32_ps(v0_512);
+        v0_value = _mm512_sub_ps(v0_value, v_zp);
+        auto v0_out = _mm512_mul_ps(v0_value, v_scale);
+        mm512_uni_storeu_ps(dst + i, v0_out);
+    }
+#elif defined(HAVE_AVX2)
+    auto v_zp = _mm256_set1_ps(zp);
+    auto v_scale = _mm256_set1_ps(scale);
+    for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) {
+        auto v0_128 = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i));
+        auto v0_256 = _mm256_cvtepu8_epi32(v0_128);
+        auto v0_value = _mm256_cvtepi32_ps(v0_256);
+        v0_value = _mm256_sub_ps(v0_value, v_zp);
+        auto v0_out = _mm256_mul_ps(v0_value, v_scale);
+        mm256_uni_storeu_ps(dst + i, v0_out);
+    }
+#endif
+    for (; i < n; ++i) {
+        float tmp = src_nc[i];
+        tmp = (tmp - zp) * scale;
+        dst[i] = tmp;
+    }
+}
+
+} // namespace XARCH
+} // namespace Cpu
+} // namespace Extensions
+} // namespace ov
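
A small usage sketch for the template above (the quantization parameters are hand-picked for illustration; in the real flow `attn_quant_u8` computes `scale` and `zp`, presumably from each token's value range):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>
#include "attn_quant_kernel.hpp"  // assumed include path within the plugin

int main() {
    // Pretend values in [-1, 1] were quantized as q = x/scale + zp,
    // with scale = 2/255 and zp = 128 (hand-picked for the example).
    const float scale = 2.0f / 255.0f;
    const float zp = 128.0f;
    std::vector<uint8_t> q = {0, 64, 128, 192, 255};
    std::vector<float> x(q.size());

    ov::Extensions::Cpu::XARCH::attn_dequant_u8_kernel(q.data(), x.data(), q.size(), scale, zp);

    // Prints roughly: -1.004 -0.502 0.000 0.502 0.996,
    // i.e. x[i] = (q[i] - zp) * scale.
    for (float f : x)
        std::printf("%.3f ", f);
}
```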