Skip to content

Commit

Permalink
[Runtime] Support KV cache with RoPE extension factor array
Browse files Browse the repository at this point in the history
This PR enhances the KV cache with the RoPE extensio factor support.
With this PR, the KV cache can support models like Phi3.5 which comes
with the extension factor.
  • Loading branch information
MasterJH5574 committed Aug 23, 2024
1 parent 2ddc0fa commit d132897
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
19 changes: 16 additions & 3 deletions src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,23 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK(args.size() == 5 || args.size() == 6)
<< "KVState AttentionWithFusedQKV only accepts 5 or 6 arguments";
AttentionKVCache kv_cache = args[0];
int64_t layer_id = args[1];
double attn_score_scaling_factor = args[2];
NDArray qkv_data = args[3];
NDArray o_data;
Optional<NDArray> ext_factors = NullOpt;
if (args.size() == 5) {
o_data = args[4];
} else {
ext_factors = args[4].operator tvm::runtime::NDArray();
o_data = args[5];
}
kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
attn_score_scaling_factor);
std::move(ext_factors), attn_score_scaling_factor);
});

// RNN State methods
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,13 @@ class AttentionKVCacheObj : public KVStateObj {
* `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`.
* \param mask The input mask data, in layout `(total_sqr_length)`.
* \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param rope_ext_factors The RoPE extension factor array in shape `(head_dim // 2,)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
* \sa AttentionKVCache::Attention
*/
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) = 0;
NDArray o_data, Optional<NDArray> rope_ext_factors,
double attn_score_scaling_factor) = 0;

/************** Positions **************/

Expand Down
14 changes: 11 additions & 3 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1685,7 +1685,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}

void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) final {
NDArray o_data, Optional<NDArray> rope_ext_factors,
double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
CHECK_GE(local_layer_id, 0);
Expand Down Expand Up @@ -1726,8 +1727,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_},
qkv_data->dtype);
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
if (!rope_ext_factors.defined()) {
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
} else {
CHECK(rope_mode_ == RoPEMode::kNormal)
<< "The RoPE mode must be normal to support RoPE extension factors.";
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
rope_ext_factors.value());
}

// Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
if (append_before_attn_) {
Expand Down

0 comments on commit d132897

Please sign in to comment.