: public typed_primitive_inst_base prefill_network;
diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
index 787fd184f75b6a..c761aaf63799cd 100644
--- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -48,14 +48,38 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no
template
std::vector paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
- auto out_layout = impl_param.get_input_layout(0);
+ auto data_layout = impl_param.get_input_layout(0);
const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. "
"Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length());
- return {out_layout};
+ std::vector output_layouts{ data_layout };
+
+ const auto& desc = impl_param.typed_desc();
+ if (desc->has_scores_output()) {
+ const auto past_lens_idx = 5;
+ const auto output_dt = data_layout.data_type;
+ if (impl_param.get_input_layout(past_lens_idx).is_static()) {
+ const auto& memory_deps = impl_param.memory_deps;
+ const auto past_lens_mem = memory_deps.at(past_lens_idx);
+ mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm);
+
+ long int total_size = 0;
+ for (size_t i = 0; i < past_lens_mem_lock.size(); i++) {
+ total_size += past_lens_mem_lock[i];
+ }
+
+ total_size += static_cast(impl_param.get_input_layout(0).get_shape()[0]);
+
+ output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
+ } else {
+ output_layouts.push_back(layout{ov::PartialShape::dynamic(1), output_dt, format::bfyx});
+ }
+ }
+
+ return output_layouts;
}
template std::vector
@@ -81,45 +105,79 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
}
void paged_attention_inst::on_execute() {
- auto stage = get_paged_attention_stage(*_impl_params);
+ const auto& desc = _impl_params->typed_desc();
+ const bool has_scores_output = desc->has_scores_output();
+ const auto stage = get_paged_attention_stage(*_impl_params);
- if (stage == PagedAttentionStage::UNKNOWN ||
- stage == PagedAttentionStage::GENERATE)
+ if ((stage == PagedAttentionStage::UNKNOWN) ||
+ (stage == PagedAttentionStage::GENERATE && !has_scores_output))
return;
+ auto& stream = get_network().get_stream();
+ const auto past_lens_mem = past_lens_memory_ptr();
+ const auto subsequence_begins_mem = subsequence_begins_memory_ptr();
+ mem_lock past_lens_mem_lock(past_lens_mem, stream);
+ mem_lock subsequence_begins_mem_lock(subsequence_begins_mem, stream);
+ std::unique_ptr> subsequence_offsets_lock = nullptr;
+
+ if (has_scores_output) {
+ const size_t subsequence_offsets_idx = 4;
+
+ OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
+ "[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");
+
+ auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
+ subsequence_offsets_lock.reset(new mem_lock(subsequence_offsets_mem, stream));
+ }
+
+ if (stage == PagedAttentionStage::GENERATE) {
+ // For the generate stage it's not necessary to configure any other intermediate
+ // buffers. Simply calculate the offsets and exit
+ size_t subsequence_offsets_acc = 0;
+ for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+ const auto past_len = past_lens_mem_lock[i];
+ const auto seq_start = subsequence_begins_mem_lock[i];
+ const auto seq_end = subsequence_begins_mem_lock[i + 1];
+ const auto seq_length = seq_end - seq_start;
+
+ if (subsequence_offsets_lock) {
+ subsequence_offsets_lock->operator[](i) = static_cast(subsequence_offsets_acc);
+ subsequence_offsets_acc += seq_length + past_len;
+ }
+ }
+
+ return;
+ }
+
OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");
const auto blocks_indexes_start_idx = 0;
const auto blocks_indexes_end_idx = 1;
const auto blocked_gws_subseq_mapping_idx = 2;
- const auto past_lens_mem = past_lens_memory_ptr();
- auto subsequence_begins_mem = subsequence_begins_memory_ptr();
auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];
OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);
- auto& stream = get_network().get_stream();
- mem_lock past_lens_mem_lock(past_lens_mem, stream);
- mem_lock subsequence_begins_mem_lock(subsequence_begins_mem, stream);
mem_lock blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
mem_lock blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
mem_lock blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
std::unique_ptr> sequential_gws_subseq_mapping_lock = nullptr;
if (stage == PagedAttentionStage::MIXED) {
- const auto sequential_gws_subseq_mapping_idx = 6;
+ const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;
OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx,
- "Unexpected number of intermediates buffers for Paged Attention for mixed stage");
+ "[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage");
auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
sequential_gws_subseq_mapping_lock.reset(new mem_lock(sequential_gws_subseq_mapping_mem, stream));
}
size_t index = 0;
+ size_t subsequence_offsets_acc = 0;
const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
const auto past_len = past_lens_mem_lock[i];
@@ -159,6 +217,11 @@ void paged_attention_inst::on_execute() {
sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast(i);
}
}
+
+ if (subsequence_offsets_lock) {
+ subsequence_offsets_lock->operator[](i) = static_cast(subsequence_offsets_acc);
+ subsequence_offsets_acc += seq_length + past_len;
+ }
}
}
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
index 00c43829d02ea7..7e960afa4b87d3 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -44,6 +44,10 @@ KERNEL(pa_sdpa_opt)(
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif
__global OUTPUT_TYPE* output,
+#if PAGED_ATTENTION_SCORES_OUTPUT
+ __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
+ const __global int* subsequence_offsets,
+#endif
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
__global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
__global OUTPUT_TYPE* tmp_out
@@ -276,6 +280,28 @@ KERNEL(pa_sdpa_opt)(
const uint max_logits_offset = exp_sums_offset;
max_logits[max_logits_offset] = qk_max;
}
+
+#if PAGED_ATTENTION_SCORES_OUTPUT
+#if MULTI_TOKENS_PROCESSING
+ const uint subsequence_idx = gws_subseq_mapping[seq_idx];
+ const uint subsequence_start_pos = subsequence_begins[subsequence_idx];
+ const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
+ const bool save_softmax_results = seq_idx == subsequence_end_pos - 1;
+#else
+ const uint subsequence_idx = seq_idx;
+ const bool save_softmax_results = true;
+#endif // MULTI_TOKENS_PROCESSING
+ // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
+ // so save SEQ_LEN_PARTITION_SIZE elements for each partition
+ if (save_softmax_results) {
+ const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
+ head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
+ partition_idx * SEQ_LEN_PARTITION_SIZE;
+ for (uint i = sgid * SUBGROUP_SIZE + sglid; i < SEQ_LEN_PARTITION_SIZE; i += SUBGROUPS_PER_WG * SUBGROUP_SIZE) {
+ softmax_results[output_offset + i] = slm_qk_vals[i];
+ }
+ }
+#endif // PAGED_ATTENTION_SCORES_OUTPUT
}
}
@@ -370,6 +396,10 @@ KERNEL(pa_sdpa_finalization_stage)(
const __global INPUT6_TYPE* subsequence_begins,
#endif
__global OUTPUT_TYPE* output,
+#if PAGED_ATTENTION_SCORES_OUTPUT
+ __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
+ const __global int* subsequence_offsets,
+#endif
const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
const __global OUTPUT_TYPE* tmp_out,
@@ -500,3 +530,155 @@ KERNEL(pa_sdpa_finalization_stage)(
}
#endif
+
+#ifdef SDPA_STAGE_2
+#define MAX_PARTITIONS_NUM 128
+
+REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
+KERNEL(pa_sdpa_scores_calculation)(
+ const __global INPUT3_TYPE* past_lens,
+ const __global INPUT6_TYPE* subsequence_begins,
+ __global OUTPUT1_TYPE* scores_output,
+ const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output,
+ const __global int* subsequence_offsets,
+ const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
+ const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
+ const __global OUTPUT_TYPE* tmp_out,
+ const uint is_mixed_mode) {
+ const uint subsequence_idx = get_global_id(2);
+ const uint partition_global_idx = get_global_id(0);
+ const uint local_id = get_local_id(0);
+ const uint partition_idx = get_group_id(0);
+ const uint partition_size = get_local_size(0);
+ const uint max_seq_len = get_global_size(0);
+ const uint partitions_num = get_num_groups(0);
+ const uint sgid = get_sub_group_id();
+ const uint sgid_num = get_num_sub_groups();
+ const uint sglid = get_sub_group_local_id();
+
+ const int subsequence_begin = subsequence_begins[subsequence_idx];
+ const int subsequence_end = subsequence_begins[subsequence_idx + 1];
+ const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx];
+
+ const uint num_of_partitions = CEIL_DIV(seq_len, partition_size);
+
+ if (partition_idx >= num_of_partitions)
+ return;
+
+ __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sums[HEADS_NUM];
+ __local SOFTMAX_ACCUMULATOR_TYPE slm_global_exp_sum[HEADS_NUM];
+
+ SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+ if (seq_len <= partition_size) {
+ // If seq_len is less than the partition size, just reduce the results over the heads
+ for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
+ const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
+ SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
+ total_score += softmax_value;
+ }
+ } else if (seq_len <= partition_size * MAX_PARTITIONS_NUM) {
+ // Optimized version for longer prompts (up to partition_size * MAX_PARTITIONS_NUM, ~64K tokens)
+
+ // Depending on the previous kernel exp_sums and max_logits might have different structure:
+ // For ordinary 1st and 2nd token kernels, there is only a single entry per subsequence.
+ // However, for mixed mode execution, exp_sums and max_logits include information for all
+ // tokens of each subsequence, but only the last one is needed for score calculation.
+ const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;
+
+ for (uint head_idx = sgid; head_idx < HEADS_NUM; head_idx += sgid_num) {
+ SOFTMAX_ACCUMULATOR_TYPE max_logit[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];
+ SOFTMAX_ACCUMULATOR_TYPE exp_sum[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];
+
+ const uint exp_sums_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
+ for (int i = 0; i < partitions_num / SUBGROUP_SIZE; i++) {
+ max_logit[i] = max_logits[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
+ exp_sum[i] = exp_sums[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
+ }
+
+ const uint partitions_leftovers = partitions_num % SUBGROUP_SIZE;
+ if (partitions_leftovers != 0) {
+ const uint idx = partitions_num / SUBGROUP_SIZE;
+ max_logit[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
+ exp_sum[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
+ }
+
+ SOFTMAX_ACCUMULATOR_TYPE global_max_logit = max_logit[0];
+ for (uint i = 1; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+ global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit[i]);
+ }
+
+ global_max_logit = sub_group_reduce_max(global_max_logit);
+
+ SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+ for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+ SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum[i] * native_exp(max_logit[i] - global_max_logit);
+ // slm_exp_sums[head_idx][i * SUBGROUP_SIZE + sglid] = adjusted_exp_sum;
+ if (i * SUBGROUP_SIZE + sglid == partition_idx)
+ slm_exp_sums[head_idx] = adjusted_exp_sum;
+ global_exp_sum += adjusted_exp_sum;
+ }
+
+ global_exp_sum = sub_group_reduce_add(global_exp_sum);
+
+ slm_global_exp_sum[head_idx] = global_exp_sum;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
+ SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = slm_exp_sums[head_idx];
+ SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = slm_global_exp_sum[head_idx];
+
+ const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
+ SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
+
+ softmax_value = softmax_value * adjusted_exp_sum / global_exp_sum;
+ total_score += softmax_value;
+ }
+ } else {
+ // Non optimized fallback version
+ const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;
+ for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
+ SOFTMAX_ACCUMULATOR_TYPE global_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
+ const uint max_logits_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
+ for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+ const uint partition_offset = i * SUBGROUP_SIZE + sglid;
+ SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
+ global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit);
+ }
+
+ global_max_logit = sub_group_reduce_max(global_max_logit);
+
+ SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+ SOFTMAX_ACCUMULATOR_TYPE partition_adjusted_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+ const uint exp_sums_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
+ for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
+ const uint partition_offset = i * SUBGROUP_SIZE + sglid;
+ SOFTMAX_ACCUMULATOR_TYPE exp_sum = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_base_offset + partition_offset];
+ SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
+ SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit);
+ global_exp_sum += adjusted_exp_sum;
+
+ // Save and broadcast the adjusted exp_sum for the currently being processed partition
+ if (i == partition_idx / SUBGROUP_SIZE)
+ partition_adjusted_exp_sum = sub_group_broadcast(adjusted_exp_sum, partition_idx % SUBGROUP_SIZE);
+ }
+
+ global_exp_sum = sub_group_reduce_add(global_exp_sum);
+
+ const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
+ SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
+
+ softmax_value = softmax_value * partition_adjusted_exp_sum / global_exp_sum;
+ total_score += softmax_value;
+ }
+ }
+
+ const uint output_offset = subsequence_offsets[subsequence_idx];
+ if (partition_global_idx < seq_len) {
+ scores_output[output_offset + partition_global_idx] = total_score;
+ }
+}
+
+#undef MAX_PARTITIONS_NUM
+#endif
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_data_b_fs_yx_fsv16_fsv32_to_bfyx.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_data_b_fs_yx_fsv16_fsv32_to_bfyx.cl
index 95f0d0ff399a3b..ee27d220e30ce9 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_data_b_fs_yx_fsv16_fsv32_to_bfyx.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_data_b_fs_yx_fsv16_fsv32_to_bfyx.cl
@@ -66,10 +66,7 @@ KERNEL (reorder_data_b_fs_yx_fsv16_fsv32_to_bfyx)(
#if (TILE_SIZE == DEFAULT_TILE_SIZE)
- // read
- INPUTVTYPE read_data = AS_INPUTVTYPE(_sub_group_block_read8((const __global uint*)(input) + input_idx_tile));
-
- // write
+ // write index
const uint output_idx = OUTPUT_GET_TILED_INDEX(OUTPUT_TILED_ORDER);
if (F_NO_REMAINDER_CONDITION
@@ -79,13 +76,25 @@ KERNEL (reorder_data_b_fs_yx_fsv16_fsv32_to_bfyx)(
) {
#ifdef X_REMAINDER_SIZE
if (X_REMAINDER_CONDITION) {
+ // read
+ INPUTVTYPE read_data;
+ for (int j = 0; j < X_REMAINDER_SIZE; ++j) {
+ read_data[j] = AS_INPUT0_TYPE(_sub_group_block_read((const __global uint*)(input) + input_idx_tile + j * DEFAULT_STRIDE));
+ }
+ // write
for (int i = 0 ; i < X_REMAINDER_SIZE; i++) {
output[output_idx + i] = TO_OUTPUT_TYPE(read_data[i]);
}
} else {
+ // read
+ INPUTVTYPE read_data = AS_INPUTVTYPE(_sub_group_block_read8((const __global uint*)(input) + input_idx_tile));
+ // write
VSTORE(TO_OUTPUTVTYPE(read_data), 0, output + output_idx);
}
#else
+ // read
+ INPUTVTYPE read_data = AS_INPUTVTYPE(_sub_group_block_read8((const __global uint*)(input) + input_idx_tile));
+ // write
VSTORE(TO_OUTPUTVTYPE(read_data), 0, output + output_idx);
#endif
}
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
index 55f87e4189d9fe..cddafe62623d9e 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -842,6 +842,14 @@ KERNEL(sdpa_opt)(
const __global int* blocked_indexes_start,
const __global int* blocked_indexes_end,
const __global int* gws_seq_indexes_correspondence
+#if PAGED_ATTENTION_SCORES_OUTPUT
+ , __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results
+ , const __global int* subsequence_offsets
+ , __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums
+ , __global SOFTMAX_ACCUMULATOR_TYPE* max_logits
+ , __global OUTPUT_TYPE* tmp_out
+ , const uint aligned_max_context_len
+#endif
#else
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
__global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
@@ -1222,6 +1230,39 @@ KERNEL(sdpa_opt)(
slm_qk_vals[sglid * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
}
+#if PAGED_ATTENTION_SCORES_OUTPUT
+ const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim];
+ const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
+ const uint block_start_pos = blocked_indexes_start[target_seq_dim];
+ const uint block_end_pos = blocked_indexes_end[target_seq_dim];
+
+ // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
+ // so save SEQ_LEN_PARTITION_SIZE elements for each partition
+ if (subsequence_end_pos == block_end_pos) {
+ const uint last_row_idx = block_end_pos - block_start_pos - 1;
+ if (sglid == last_row_idx) {
+ const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE;
+
+ if (sgid == 0) {
+ const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE;
+ const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num +
+ num_heads_dim * max_partitions_num +
+ partition_idx;
+ exp_sums[exp_sums_output_offset] = exp_sum_new;
+ max_logits[exp_sums_output_offset] = qk_max_new;
+ }
+
+ const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
+ num_heads_dim * aligned_max_context_len +
+ partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE;
+ for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
+ softmax_results[output_offset + i] = qk_acc[i];
+ }
+
+ }
+ }
+#endif
+
barrier(CLK_LOCAL_MEM_FENCE);
}
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp
index ddfb491f50278a..ce20f49de597ff 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp
@@ -167,7 +167,7 @@ void KVCacheUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const {
const auto indexes_dt = Datatype::INT32;
const auto target_seq_len_block_size = 16;
- const auto target_seq_len = prim_params.conf.paged_attention_aligned_seq_len;
+ const auto target_seq_len = std::max(prim_params.conf.paged_attention_aligned_seq_len, static_cast(1));
const auto indexes_buf_size = CeilDiv(target_seq_len, target_seq_len_block_size) * BytesPerElement(indexes_dt);
kd.internalBufferSizes.clear();
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
index 63c5e74160f652..909a40d677f535 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
+#include "sdpa_kernel_opt.h"
#include "pa_sdpa_kernel_opt.h"
#include "kernel_selector_params.h"
@@ -15,6 +16,7 @@ enum KernelsTypes {
MULTI_TOKENS,
FINALIZATION,
FINALIZATION_MULTI_TOKENS,
+ SCORES_CALCULATION,
TOTAL_KERNELS_NUM
};
@@ -35,6 +37,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
kernel_name += "_finalization";
} else if (type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
kernel_name += "_finalization_multi_tokens_seq";
+ } else if (type == KernelsTypes::SCORES_CALCULATION) {
+ kernel_name += "_scores_calculation";
}
return kernel_name;
@@ -46,10 +50,15 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
}
const auto& params = static_cast(p);
- const std::vector kernels_type = { KernelsTypes::SINGLE_TOKEN,
- KernelsTypes::MULTI_TOKENS,
- KernelsTypes::FINALIZATION,
- KernelsTypes::FINALIZATION_MULTI_TOKENS };
+ std::vector kernels_type = { KernelsTypes::SINGLE_TOKEN,
+ KernelsTypes::MULTI_TOKENS,
+ KernelsTypes::FINALIZATION,
+ KernelsTypes::FINALIZATION_MULTI_TOKENS };
+
+ const auto has_scores_output = params.outputs.size() > 1;
+ if (has_scores_output) {
+ kernels_type.push_back(KernelsTypes::SCORES_CALCULATION);
+ }
KernelData kd = KernelData::Default(params, kernels_type.size());
kd.needs_sub_kernels_sync = true;
@@ -65,7 +74,8 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
const auto jit = CreateJit(kernel_name, jit_constants, entry_point);
- size_t inputs_num = static_cast(params.inputs.size());
+ int inputs_num = static_cast(params.inputs.size());
+ int outputs_num = 1;
if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
// SINGLE_TOKEN kernel doesn't use the subsequence_begins input
inputs_num -= 1;
@@ -75,6 +85,11 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
} else if (kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
// FINALIZATION_MULTI_TOKENS kernel uses past_lens data input and subsequence_begins
inputs_num = 2;
+ } else if (kernel_type == KernelsTypes::SCORES_CALCULATION) {
+ // SCORES_CALCULATION kernel uses past_lens data input and subsequence_begins
+ inputs_num = 2;
+ // Output is configured manually to use the second output memory buffer
+ outputs_num = 0;
}
auto& kernel = kd.kernels[kd_kernels_idx++];
@@ -87,19 +102,33 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
{},
false,
false,
- static_cast(inputs_num),
+ inputs_num,
GetFusedPrimitiveInputsCount(params),
- static_cast(params.outputs.size()),
+ outputs_num,
params.is_shape_agnostic);
- kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
- kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
- kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
+ if (kernel_type == KernelsTypes::SCORES_CALCULATION) {
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 1});
+ }
+
+ uint32_t internal_buffers_num = 0;
+ if (has_scores_output) {
+ // Intermediate softmax results for scores output calculation and precalculated accumulated
+ // sequence length offsets for each subsequence
+ internal_buffers_num += 2;
+ }
+
+ // Softmax's exp_sums, max_logits and intermediate output
+ internal_buffers_num += 3;
if (kernel_type == KernelsTypes::MULTI_TOKENS || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
// MULTIPLE_TOKENS kernels needs additional information related to mapping
// launched kernel instances to subsequence indexes
- kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3});
+ internal_buffers_num++;
+ }
+
+ for (uint32_t i = 0; i < internal_buffers_num; i++) {
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, i});
}
if (kernel_type == KernelsTypes::FINALIZATION || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
@@ -108,6 +137,15 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
// Remove unused shape_info argument at finalization stage
kernel.params.arguments.erase(kernel.params.arguments.begin());
}
+
+ if (kernel_type == KernelsTypes::SCORES_CALCULATION) {
+ // The scores kernel needs to know if the current execution mode is mixed or ordinary
+ // to configure proper memory access
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
+
+ // Remove unused shape_info argument for scores kernel
+ kernel.params.arguments.erase(kernel.params.arguments.begin());
+ }
}
return {kd};
@@ -173,7 +211,12 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));
}
- auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS ? 1 : 0;
+ auto sdpa_stage = 0;
+ if (kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
+ sdpa_stage = 1;
+ } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
+ sdpa_stage = 2;
+ }
jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1));
if (config.has_const_scale_val) {
@@ -190,6 +233,10 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
jit.Merge(MakeTypeJitConstants(params.inputs[alibi_input_idx].GetDType(), "ALIBI_INPUT"));
}
+ if (params.outputs.size() > 1) {
+ jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_SCORES_OUTPUT", 1));
+ }
+
if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS)
jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
@@ -203,18 +250,36 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
const auto& input = params.inputs[0];
if (!input.is_dynamic()) {
- const size_t sequences_number = input.Batch().v;
- const size_t num_of_partitions = CeilDiv(params.max_context_len, seq_len_partition_size);
+ const size_t total_tokens = input.Batch().v;
+ const size_t num_of_partitions = CeilDiv(params.conf.paged_attention_max_len, seq_len_partition_size);
const size_t heads_num = static_cast(params.conf.heads_num);
const size_t head_size = static_cast(params.conf.head_size);
- if (kernel_idx == 0) {
- dispatch_data.gws = { sequences_number,
+ if (kernel_idx == KernelsTypes::SINGLE_TOKEN || kernel_idx == KernelsTypes::MULTI_TOKENS) {
+ dispatch_data.gws = { total_tokens,
heads_num,
head_size * num_of_partitions };
dispatch_data.lws = { 1, 1, head_size };
+ } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
+ const auto& past_lens = params.inputs[3];
+ const auto subsequences_number = past_lens.Batch().v;
+
+ size_t partition_size = 0;
+ size_t num_of_partitions = 0;
+ if (params.stage == PagedAttentionStage::PREFILL) {
+ partition_size = SDPAKernelOpt::get_seq_len_partition_size(params, params.conf.head_size, 1);
+ } else {
+ partition_size = seq_len_partition_size;
+ }
+
+ num_of_partitions = CeilDiv(params.conf.paged_attention_max_len, partition_size);
+
+ dispatch_data.gws = { partition_size * num_of_partitions,
+ 1,
+ subsequences_number };
+ dispatch_data.lws = { partition_size, 1, 1 };
} else {
- dispatch_data.gws = { sequences_number,
+ dispatch_data.gws = { total_tokens,
heads_num,
head_size };
dispatch_data.lws = { 1, 1, subgroup_size };
@@ -228,30 +293,39 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
const auto& prim_params = static_cast(params);
- const size_t expected_kernels_num = 4;
- OPENVINO_ASSERT(kd.kernels.size() == expected_kernels_num, "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel");
+ const auto has_scores_output = prim_params.outputs.size() > 1;
+ const auto expected_kernels_num = has_scores_output ? KernelsTypes::TOTAL_KERNELS_NUM : KernelsTypes::TOTAL_KERNELS_NUM - 1;
+ OPENVINO_ASSERT(kd.kernels.size() == static_cast(expected_kernels_num),
+ "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel");
+
+ const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
+ const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws;
kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws;
- kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = prim_params.multi_tokens_mode;
+ kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws;
kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws;
- kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !prim_params.multi_tokens_mode;
+ kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;
- const auto& input = prim_params.inputs[0];
- const size_t sequences_number = input.Batch().v;
- const size_t num_of_partitions = CeilDiv(prim_params.max_context_len, seq_len_partition_size);
+ size_t partition_size = 0;
+ if (prim_params.stage == PagedAttentionStage::PREFILL) {
+ partition_size = SDPAKernelOpt::get_seq_len_partition_size(params, prim_params.conf.head_size, 1);
+ } else {
+ partition_size = seq_len_partition_size;
+ }
+ const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, partition_size);
auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION);
kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws;
kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws;
- kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || prim_params.multi_tokens_mode;
+ kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;
kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws;
kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws;
- kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !prim_params.multi_tokens_mode;
+ kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;
ScalarDescriptor num_of_partitions_scalar;
num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
@@ -261,23 +335,63 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars.resize(1);
kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar;
+ if (has_scores_output) {
+ auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
+ kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws;
+ kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws;
+ kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false;
+
+ ScalarDescriptor is_mixed_mode;
+ is_mixed_mode.t = ScalarDescriptor::Types::UINT32;
+        is_mixed_mode.v.u32 = static_cast<uint32_t>(multi_tokens_mode);
+ kd.kernels[KernelsTypes::SCORES_CALCULATION].params.scalars.resize(1);
+ kd.kernels[KernelsTypes::SCORES_CALCULATION].params.scalars[0] = is_mixed_mode;
+ }
+
+ const auto& input = prim_params.inputs[0];
+ const size_t total_tokens = input.Batch().v;
+
auto buf_dt_size = BytesPerElement(softmax_acc_dt);
- auto buf_elements_count = sequences_number * prim_params.conf.heads_num * num_of_partitions;
+ auto buf_elements_count = total_tokens * prim_params.conf.heads_num * num_of_partitions;
auto buf_size = buf_elements_count * buf_dt_size;
auto tmp_out_dt_size = BytesPerElement(softmax_acc_dt);
- auto tmp_out_elements_count = sequences_number * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions;
+ auto tmp_out_elements_count = total_tokens * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions;
auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size;
kd.internalBufferSizes.clear();
- kd.internalBufferSizes.push_back(buf_size);
- kd.internalBufferSizes.push_back(buf_size);
- kd.internalBufferSizes.push_back(tmp_out_size);
+
+ if (has_scores_output) {
+ const auto& past_lens = prim_params.inputs[3];
+ auto subsequences_number = past_lens.Batch().v;
+ auto softmax_buf_dt_size = BytesPerElement(softmax_acc_dt);
+
+ auto softmax_buf_elements_count = subsequences_number * prim_params.conf.heads_num * num_of_partitions * partition_size;
+ auto softmax_buf_size = softmax_buf_elements_count * softmax_buf_dt_size;
+
+ // Softmax intermediate output
+ kd.internalBufferSizes.push_back(softmax_buf_size);
+ // Precalculated accumulated sequence length offsets for each subsequence
+ kd.internalBufferSizes.push_back(subsequences_number * BytesPerElement(Datatype::INT32));
+
+ if (prim_params.stage == PagedAttentionStage::PREFILL) {
+ // Recalculate buf_size as in case of PREFILL stage it's not needed to allocate buffer per each input token
+ buf_elements_count = subsequences_number * prim_params.conf.heads_num * num_of_partitions;
+ buf_size = buf_elements_count * buf_dt_size;
+
+ // Intermediate tmp output buffer is not used for PREFILL stage
+ tmp_out_size = tmp_out_dt_size;
+ }
+ }
+
+ kd.internalBufferSizes.push_back(buf_size); // softmax exp_sums
+ kd.internalBufferSizes.push_back(buf_size); // softmax max_logits
+ kd.internalBufferSizes.push_back(tmp_out_size); // intermediate output
kd.internalBufferDataType = softmax_acc_dt;
- if (prim_params.multi_tokens_mode) {
+ if (multi_tokens_mode) {
auto buf_dt_size = BytesPerElement(Datatype::INT32);
- auto buf_elements_count = sequences_number;
+ auto buf_elements_count = total_tokens;
auto buf_size = Align(buf_elements_count * buf_dt_size, BytesPerElement(softmax_acc_dt));
kd.internalBufferSizes.push_back(buf_size);
}
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h
index a2456ccd9e2af5..a52571b03691df 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h
@@ -9,11 +9,17 @@
namespace kernel_selector {
+enum PagedAttentionStage {
+ GENERATE = 0,
+ PREFILL = 1,
+ MIXED = 2,
+ UNKNOWN = 3
+};
+
struct pa_sdpa_params : base_params {
pa_sdpa_params() : base_params(KernelType::PA_SDPA) {}
- bool multi_tokens_mode = false;
- size_t max_context_len = 0;
+ PagedAttentionStage stage = PagedAttentionStage::UNKNOWN;
sdpa_configuration conf;
};
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
index 5cd9c384ff2709..8fcc4a16692d6c 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
@@ -97,6 +97,7 @@ struct sdpa_configuration {
bool is_paged_attention = false;
int64_t paged_attention_aligned_seq_len = -1;
int64_t paged_attention_block_size = 0;
+ int64_t paged_attention_max_len = 0;
bool has_const_scale_val = false;
float scale_val = 0.f;
};
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp
index 4e71064efbc895..4c23d4de4fd68d 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp
@@ -21,38 +21,11 @@ enum KernelsTypes {
constexpr size_t subgroup_size = 16;
} // namespace
-static size_t get_sg_number_scale_factor(const sdpa_params& sdpa_params, size_t kernel_type) {
- const size_t optimal_scale_factor = 2;
- if (kernel_type == KernelsTypes::MULTI_TOKENS) {
- if (sdpa_params.conf.head_size * optimal_scale_factor <= sdpa_params.engineInfo.maxWorkGroupSize) {
- return optimal_scale_factor;
- }
- } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
- if (sdpa_params.conf.head_size * optimal_scale_factor <= sdpa_params.engineInfo.maxWorkGroupSize &&
- sdpa_params.conf.head_size * optimal_scale_factor / subgroup_size <= subgroup_size) {
- return optimal_scale_factor;
- }
- }
-
- return 1;
-}
-
static size_t get_target_seq_len_block_size() {
const size_t block_size = 16;
return block_size;
}
-static size_t get_seq_len_partition_size(const sdpa_params& sdpa_params, size_t kernel_type) {
- size_t seq_len = 0;
- if (kernel_type == KernelsTypes::MULTI_TOKENS) {
- seq_len = sdpa_params.conf.head_size * get_sg_number_scale_factor(sdpa_params, kernel_type);
- } else {
- seq_len = 256;
- }
-
- return seq_len;
-}
-
static Datatype get_softmax_acc_type() {
return Datatype::F32;
}
@@ -71,7 +44,7 @@ static size_t get_partitions_num(const sdpa_params& sdpa_params, size_t kernel_t
TransposedDimensionAccessHelperBase dims_k(sdpa_params.inputs[1], sdpa_params.input1_order);
auto source_seq_len = dims_k.y_dim().v;
- return CeilDiv(source_seq_len, get_seq_len_partition_size(sdpa_params, kernel_type));
+ return CeilDiv(source_seq_len, SDPAKernelOpt::get_seq_len_partition_size(sdpa_params, sdpa_params.conf.head_size, kernel_type));
}
static std::vector get_internal_buffer_sizes(const sdpa_params& sdpa_params, size_t kernel_type) {
@@ -130,6 +103,33 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type, const
return kernel_name;
}
+size_t SDPAKernelOpt::get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type) {
+ const size_t optimal_scale_factor = 2;
+ if (kernel_type == KernelsTypes::MULTI_TOKENS) {
+ if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) {
+ return optimal_scale_factor;
+ }
+ } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
+ if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize &&
+ head_size * optimal_scale_factor / subgroup_size <= subgroup_size) {
+ return optimal_scale_factor;
+ }
+ }
+
+ return 1;
+}
+
+size_t SDPAKernelOpt::get_seq_len_partition_size(const Params& params, size_t head_size, size_t kernel_type) {
+ size_t seq_len = 0;
+ if (kernel_type == KernelsTypes::MULTI_TOKENS) {
+ seq_len = head_size * get_sg_number_scale_factor(params, head_size, kernel_type);
+ } else {
+ seq_len = 256;
+ }
+
+ return seq_len;
+}
+
ParamsKey SDPAKernelOpt::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::INT8);
@@ -176,14 +176,14 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
const auto& config = params.conf;
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size));
- jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size(params, kernel_idx)));
+ jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size(params, config.head_size, kernel_idx)));
auto target_seq_len_block_size = kernel_idx == KernelsTypes::SINGLE_TOKEN ? 1 : get_target_seq_len_block_size();
jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN_BLOCK_SIZE", target_seq_len_block_size));
auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION ? 1 : 0;
jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1));
- jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, kernel_idx)));
+ jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx)));
if (params.conf.is_paged_attention) {
if (params.conf.has_alibi_input) {
@@ -196,6 +196,10 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
} else {
jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", 1));
}
+
+ if (params.outputs.size() > 1) {
+ jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_SCORES_OUTPUT", 1));
+ }
} else if (params.inputs.size() <= 4) {
jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE_INV", std::sqrt(static_cast(params.conf.head_size))));
jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE", 1.0f / std::sqrt(static_cast(params.conf.head_size))));
@@ -218,11 +222,11 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
if (params.conf.is_paged_attention) {
OPENVINO_ASSERT(kernel_idx == KernelsTypes::MULTI_TOKENS);
- const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx);
const size_t heads_num = static_cast(params.conf.heads_num);
+        const size_t head_size = static_cast<size_t>(params.conf.head_size);
+ const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
const size_t target_seq_len_block_size = get_target_seq_len_block_size();
const size_t target_seq_len = static_cast(params.conf.paged_attention_aligned_seq_len);
- const size_t head_size = static_cast(params.conf.head_size);
dispatch_data.gws = { heads_num,
CeilDiv(target_seq_len, target_seq_len_block_size),
@@ -243,13 +247,13 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
const size_t target_seq_len_block_size = kernel_idx == 1 ? get_target_seq_len_block_size() : 1;
if (kernel_idx == KernelsTypes::SINGLE_TOKEN) {
- const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx);
+ const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
dispatch_data.gws = { batch_size * heads_num,
CeilDiv(target_seq_len, target_seq_len_block_size),
head_size * num_of_partitions * sg_num_scale };
dispatch_data.lws = { 1, 1, head_size * sg_num_scale };
} else if (kernel_idx == KernelsTypes::MULTI_TOKENS) {
- const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx);
+ const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
dispatch_data.gws = { batch_size * heads_num,
CeilDiv(target_seq_len, target_seq_len_block_size),
head_size * sg_num_scale };
@@ -317,7 +321,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
false,
inputs_num,
GetFusedPrimitiveInputsCount(params),
- static_cast(prim_params.outputs.size()),
+ 1 /* number_of_outputs */,
prim_params.is_shape_agnostic);
auto beam_table_idx = prim_params.inputs.size();
@@ -339,6 +343,19 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
+ if (prim_params.conf.is_paged_attention && prim_params.outputs.size() > 1) {
+ // Intermediate buffers for PagedAttention scores calculation:
+ // softmax_results, subsequence_offsets, exp_sums, max_logits, tmp_out
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3});
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4});
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5});
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 6});
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 7});
+
+ // Scalar used for proper offset calculation of intermediate data buffers
+ kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
+ }
+
const auto buf_sizes = get_internal_buffer_sizes(prim_params, kernel_idx);
if (!prim_params.conf.is_paged_attention) {
kd.internalBufferSizes.clear();
@@ -379,6 +396,15 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
kernel_data.kernels[0].params.workGroups.global = dispatch_data.gws;
kernel_data.kernels[0].params.workGroups.local = dispatch_data.lws;
kernel_data.kernels[0].skip_execution = false;
+
+ if (prim_params.outputs.size() > 1) {
+ const auto max_seq_len = prim_params.conf.paged_attention_max_len;
+ const auto seq_len_partition_size = get_seq_len_partition_size(params, prim_params.conf.head_size, KernelsTypes::MULTI_TOKENS);
+
+ kernel_data.kernels[0].params.scalars.resize(1);
+ kernel_data.kernels[0].params.scalars[0].t = ScalarDescriptor::Types::UINT32;
+            kernel_data.kernels[0].params.scalars[0].v.u32 = static_cast<uint32_t>(Align(max_seq_len, seq_len_partition_size));
+ }
} else {
const auto num_of_partitions = get_partitions_num(prim_params, KernelsTypes::SINGLE_TOKEN);
const auto buf_sizes = get_internal_buffer_sizes(prim_params, KernelsTypes::SINGLE_TOKEN);
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h
index 8d7279f5546112..a4d351498d7075 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h
@@ -17,6 +17,9 @@ class SDPAKernelOpt : public SDPAKernelBase {
KernelsPriority GetKernelsPriority(const Params& params) const override;
ParamsKey GetSupportedKey() const override;
+ static size_t get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type);
+ static size_t get_seq_len_partition_size(const Params& params, size_t head_size, size_t kernel_type);
+
protected:
bool Validate(const Params& p) const override;
void GetUpdateDispatchDataFunc(KernelData& kd) const override;
diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
index 7425b096b6d324..d82d3a66fed7f7 100644
--- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
@@ -61,10 +61,13 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
OPENVINO_ASSERT(alibi_const != nullptr);
prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0;
+ prim.num_outputs = 1;
if (op->get_output_size() > 1) {
const auto scores_output_idx = 1;
const auto& users = op->get_output_target_inputs(scores_output_idx);
- OPENVINO_ASSERT(users.size() == 0, "[GPU] PagedAttention implementation doesn't support scores output yet");
+ if (users.size() > 0) {
+ prim.num_outputs++; // Add scores output
+ }
}
p.add_primitive(*op, prim);
diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
new file mode 100644
index 00000000000000..a32ef3325cd9bc
--- /dev/null
+++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp
@@ -0,0 +1,687 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "test_utils.h"
+#include "random_generator.hpp"
+
+#include <intel_gpu/primitives/input_layout.hpp>
+#include <intel_gpu/primitives/paged_attention.hpp>
+#include <intel_gpu/primitives/permute.hpp>
+#include <intel_gpu/primitives/gemm.hpp>
+#include <intel_gpu/primitives/eltwise.hpp>
+#include <intel_gpu/primitives/softmax.hpp>
+#include <intel_gpu/primitives/reorder.hpp>
+#include <intel_gpu/primitives/data.hpp>
+
+using namespace cldnn;
+using namespace ov::intel_gpu;
+using namespace ::tests;
+
+/*
+* PagedAttention inputs:
+* [0]: query
+* shape: [batch_size_in_tokens, num_heads * head_size], type: f16
+* [1]: key
+* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16
+* [2]: value
+* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16
+* [3]: key_cache
+* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16
+* [4]: value_cache
+* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16
+* [5]: past_lens
+* shape: [batch_size_in_sequences], type: i32
+* [6]: subsequence_begins
+* shape: [batch_size_in_sequences + 1], type: i32
+* [7]: block_indices
+* Shape: [num_blocks], type: i32
+* [8]: block_indices_begins
+* Shape: [batch_size_in_sequences + 1], type: i32
+* [9]: scale, optional
+* [10]: sliding_window, optional
+* [11]: alibi_slopes, optional
+* [12]: max_context_len
+* shape: [], type: i32
+*/
+
+struct SubsequenceDescriptor {
+ int num_tokens;
+ int past_len;
+};
+
+struct PagedAttentionManager {
+ int num_heads;
+ int head_size;
+ int block_size;
+    std::vector<SubsequenceDescriptor> subsequence_descs;
+
+ // per-subsequence QKV inputs
+    std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, head_size], ..}
+    std::vector<std::vector<ov::float16>> key_data;   // {[1, past_len + num_tokens, num_heads, head_size], ..}
+    std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, head_size], ..}
+
+ // common PA inputs
+    std::vector<int> past_lens;
+    std::vector<int> subsequence_begins;
+    std::vector<int> block_indices;
+    std::vector<int> block_indices_begins;
+    std::vector<int> max_context_len;
+
+ cldnn::engine& test_engine;
+ cldnn::stream& test_stream;
+ tests::random_generator& rg;
+
+ PagedAttentionManager(tests::random_generator& rg,
+ cldnn::engine& engine,
+ cldnn::stream& stream,
+                          const std::vector<SubsequenceDescriptor>& subsequence_descs,
+ int num_heads,
+ int head_size,
+ int block_size)
+ : num_heads(num_heads)
+ , head_size(head_size)
+ , block_size(block_size)
+ , subsequence_descs(subsequence_descs)
+ , test_engine(engine)
+ , test_stream(stream)
+ , rg(rg) {
+ // init subsequence_begins and block_indices_begins
+ subsequence_begins.push_back(0);
+ block_indices_begins.push_back(0);
+
+ int max_len = 0;
+        for (int i = 0; i < static_cast<int>(subsequence_descs.size()); i++) {
+ const auto& subsequence_desc = subsequence_descs[i];
+ max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len);
+
+ query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, head_size));
+ key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, head_size));
+ value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, head_size));
+
+ past_lens.push_back(subsequence_desc.past_len);
+ int subsequence_start_pos = subsequence_begins[i];
+ int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens;
+ subsequence_begins.push_back(subsequence_end_pos);
+
+ int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len;
+ int required_blocks = ceil_div(subsequence_length, block_size);
+ int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1;
+ int end_block_idx = start_block_idx + required_blocks;
+ for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
+ block_indices.push_back(block_idx);
+ }
+
+ int block_indices_start_pos = block_indices_begins[i];
+ int block_indices_end_pos = block_indices_start_pos + required_blocks;
+ block_indices_begins.push_back(block_indices_end_pos);
+ }
+ max_context_len.push_back(max_len);
+ }
+
+ memory::ptr get_query_memory() {
+ return get_QKV_memory(query_data, false);
+ }
+
+ memory::ptr get_key_memory() {
+ return get_QKV_memory(key_data, true);
+ }
+
+ memory::ptr get_value_memory() {
+ return get_QKV_memory(value_data, true);
+ }
+
+ memory::ptr get_key_cache_memory() {
+ auto num_blocks = block_indices.back() + 1;
+ auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, head_size, block_size };
+ auto key_cache_layout = layout{ key_cache_shape, data_types::f16, format::bfyx };
+ auto memory = test_engine.allocate_memory(key_cache_layout);
+
+        for (int i = 0; i < static_cast<int>(subsequence_descs.size()); i++) {
+ int past_len = subsequence_descs[i].past_len;
+ if (past_len != 0) {
+ int blocks_num = ceil_div(past_len, block_size);
+ int start_block_idx = block_indices[block_indices_begins[i]];
+ for (int block_idx = 0; block_idx < blocks_num; block_idx++) {
+                    int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_idx * block_size)
+                                                                     : block_size;
+ for (int token_idx = 0; token_idx < last_token_idx; token_idx++) {
+ for (int head_idx = 0; head_idx < num_heads; head_idx++) {
+ for (int head_size_idx = 0; head_size_idx < head_size; head_size_idx++) {
+ size_t input_token_offset = block_idx * block_size + token_idx;
+ ov::float16* data_ptr = key_data[i].data() +
+ input_token_offset * num_heads * head_size +
+ head_idx * head_size + head_size_idx;
+
+ // shape: [num_blocks, num_heads, head_size, block_size]
+ size_t output_offset = (start_block_idx + block_idx) * num_heads * head_size * block_size +
+ head_idx * head_size * block_size +
+ head_size_idx * block_size +
+ token_idx;
+
+ set_values(test_stream, memory, data_ptr, 1, output_offset);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return memory;
+ }
+
+ memory::ptr get_value_cache_memory() {
+ auto num_blocks = block_indices.back() + 1;
+ auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, head_size };
+ auto value_cache_layout = layout{ value_cache_shape, data_types::f16, format::bfyx };
+ auto memory = test_engine.allocate_memory(value_cache_layout);
+
+        for (int i = 0; i < static_cast<int>(subsequence_descs.size()); i++) {
+ int past_len = subsequence_descs[i].past_len;
+ if (past_len != 0) {
+ int blocks_num = ceil_div(past_len, block_size);
+ int start_block_idx = block_indices[block_indices_begins[i]];
+ for (int block_idx = 0; block_idx < blocks_num; block_idx++) {
+                    int last_token_idx = block_idx == blocks_num - 1 ? (past_len - block_idx * block_size)
+                                                                     : block_size;
+ for (int token_idx = 0; token_idx < last_token_idx; token_idx++) {
+ for (int head_idx = 0; head_idx < num_heads; head_idx++) {
+ size_t input_token_offset = block_idx * block_size + token_idx;
+ ov::float16* data_ptr = value_data[i].data() +
+ input_token_offset * num_heads * head_size +
+ head_idx * head_size;
+
+ // shape: [num_blocks, num_heads, block_size, head_size]
+ size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * head_size +
+ head_idx * block_size * head_size +
+ token_idx * head_size;
+
+ set_values(test_stream, memory, data_ptr, head_size, output_offset);
+ }
+ }
+ }
+ }
+ }
+
+ return memory;
+ }
+
+ memory::ptr get_past_lens_memory() {
+ return get_memory_from_vec(past_lens);
+ }
+
+ memory::ptr get_subsequence_begins_memory() {
+ return get_memory_from_vec(subsequence_begins);
+ }
+
+ memory::ptr get_block_indices_memory() {
+ return get_memory_from_vec(block_indices);
+ }
+
+ memory::ptr get_block_indices_begins_memory() {
+ return get_memory_from_vec(block_indices_begins);
+ }
+
+ memory::ptr get_scale_memory() {
+        std::vector<ov::float16> scale = { ov::float16(get_default_scale()) };
+ return get_memory_from_vec(scale);
+ }
+
+ memory::ptr get_sliding_window_memory() {
+        std::vector<int> sliding_window = { 0 };
+ return get_memory_from_vec(sliding_window);
+ }
+
+ memory::ptr get_alibi_memory() {
+        std::vector<ov::float16> alibi;
+ return get_memory_from_vec(alibi);
+ }
+
+ memory::ptr get_max_context_len_memory() {
+ return get_memory_from_vec(max_context_len);
+ }
+
+ float get_default_scale() {
+ return static_cast(1.f / std::sqrt(head_size));
+ }
+
+private:
+    template<typename T>
+    memory::ptr get_memory_from_vec(std::vector<T>& input_data) {
+        auto data_size = input_data.empty() ? 1 : input_data.size();
+        auto shape = ov::PartialShape{ static_cast<int64_t>(data_size) };
+        auto layout = cldnn::layout{ shape, ov::element::from<T>(), format::bfyx };
+ auto memory = test_engine.allocate_memory(layout);
+
+ if (input_data.empty()) {
+ auto shape = ov::PartialShape{0};
+            auto layout = cldnn::layout{ shape, ov::element::from<T>(), format::bfyx };
+ return test_engine.reinterpret_buffer(*memory, layout);
+ }
+
+ set_values(test_stream, memory, input_data.data(), input_data.size(), 0);
+
+ return memory;
+ }
+
+    memory::ptr get_QKV_memory(std::vector<std::vector<ov::float16>>& input_data, bool skip_past_len) {
+ int total_tokens = 0;
+ for (const auto& subsequence_desc : subsequence_descs)
+ total_tokens += subsequence_desc.num_tokens;
+
+ auto query_shape = ov::PartialShape{ total_tokens, num_heads * head_size };
+ auto query_layout = layout{ query_shape, data_types::f16, format::bfyx };
+ auto memory = test_engine.allocate_memory(query_layout);
+
+        for (int subsequence_idx = 0; subsequence_idx < static_cast<int>(subsequence_descs.size()); subsequence_idx++) {
+ for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) {
+ for (int head_idx = 0; head_idx < num_heads; head_idx++) {
+ size_t input_token_offset = token_idx;
+ // as generated data stored in vectors includes past_len, ignore it for KV inputs
+ if (skip_past_len)
+ input_token_offset += subsequence_descs[subsequence_idx].past_len;
+
+ ov::float16* data_ptr = input_data[subsequence_idx].data() +
+ input_token_offset * num_heads * head_size +
+ head_idx * head_size;
+
+ size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx;
+ size_t output_offset = output_token_offset * num_heads * head_size +
+ head_idx * head_size;
+
+ set_values(test_stream, memory, data_ptr, head_size, output_offset);
+ }
+ }
+ }
+
+ return memory;
+ }
+
+    template<typename T>
+    static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) {
+        mem_lock<T> mem_ptr(mem, stream);
+ for (size_t i = 0; i < size; i++) {
+ mem_ptr[dst_offset + i] = vals[i];
+ }
+ }
+
+    static std::vector<ov::float16> generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t head_size) {
+ const size_t total_elements_num = tokens_num * num_heads * head_size;
+        auto data = rg.generate_random_1d<ov::float16>(total_elements_num, -1, 1);
+
+ return data;
+ }
+};
+
+struct PagedAttentionReference {
+ PagedAttentionReference(PagedAttentionManager& pam)
+ : pam(pam)
+ , test_engine(pam.test_engine)
+ , test_stream(pam.test_stream) {}
+
+    std::pair<std::vector<ov::float16>, std::vector<ov::float16>> get_reference() {
+        std::vector<ov::float16> ref_data_output;
+        std::vector<ov::float16> ref_scores_output;
+
+ for (size_t i = 0; i < pam.subsequence_descs.size(); i++) {
+ const auto& subsequence_desc = pam.subsequence_descs[i];
+ const auto kv_seq_len = subsequence_desc.num_tokens + subsequence_desc.past_len;
+ auto subsequence_ref_results = run_reference(pam.query_data[i],
+ pam.key_data[i],
+ pam.value_data[i],
+ subsequence_desc.num_tokens,
+ kv_seq_len,
+ pam.num_heads,
+ pam.head_size,
+ pam.get_default_scale());
+
+ // concatenate all subsequences into one vector
+ ref_data_output.insert(ref_data_output.end(),
+ subsequence_ref_results.first.begin(),
+ subsequence_ref_results.first.end());
+ ref_scores_output.insert(ref_scores_output.end(),
+ subsequence_ref_results.second.begin(),
+ subsequence_ref_results.second.end());
+ }
+
+ return { ref_data_output, ref_scores_output };
+ }
+
+private:
+    std::pair<std::vector<ov::float16>, std::vector<ov::float16>>
+    run_reference(const std::vector<ov::float16>& query_data,
+                  const std::vector<ov::float16>& key_data,
+                  const std::vector<ov::float16>& value_data,
+ int num_queries,
+ int num_keys,
+ int num_heads,
+ int head_size,
+ float scale) {
+ auto query_shape = ov::PartialShape{1, num_queries, num_heads, head_size};
+ auto key_shape = ov::PartialShape{1, num_keys, num_heads, head_size};
+ auto value_shape = ov::PartialShape{1, num_keys, num_heads, head_size};
+
+ auto query_layout = layout{query_shape, data_types::f16, format::bfyx};
+ auto key_layout = layout{key_shape, data_types::f16, format::bfyx};
+ auto value_layout = layout{value_shape, data_types::f16, format::bfyx};
+
+ OPENVINO_ASSERT(query_layout.count() == query_data.size());
+ OPENVINO_ASSERT(key_layout.count() == key_data.size());
+ OPENVINO_ASSERT(value_layout.count() == value_data.size());
+
+ auto query_mem = test_engine.allocate_memory(query_layout);
+ auto key_mem = test_engine.allocate_memory(key_layout);
+ auto value_mem = test_engine.allocate_memory(value_layout);
+ auto mask_mem = get_mask_mem(num_queries, num_keys, num_heads);
+
+ set_values(query_mem, query_data);
+ set_values(key_mem, key_data);
+ set_values(value_mem, value_data);
+
+ topology topology;
+ topology.add(input_layout("query", query_layout),
+ input_layout("key", key_layout),
+ input_layout("value", value_layout),
+ data("mask", mask_mem),
+ permute("query_transposed", input_info("query"), {0, 2, 1, 3}),
+ permute("key_transposed", input_info("key"), {0, 2, 1, 3}),
+ permute("value_transposed", input_info("value"), {0, 2, 1, 3}),
+ gemm("qk_gemm", { input_info("query_transposed"), input_info("key_transposed") }, data_types::f16, false, true, scale),
+ eltwise("eltwise", { input_info("qk_gemm"), input_info("mask") }, eltwise_mode::sum),
+ softmax("softmax", input_info("eltwise"), -1),
+ gemm("qkv_gemm", { input_info("softmax"), input_info("value_transposed") }, data_types::f16, false, false),
+ permute("qkv_gemm_transposed", input_info("qkv_gemm"), {0, 2, 1, 3}),
+ reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16),
+ reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16)
+ );
+
+ ExecutionConfig config = get_test_default_config(test_engine);
+ config.set_property(ov::intel_gpu::optimize_data(true));
+ config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+
+ network::ptr network = get_network(test_engine, topology, config, get_test_stream_ptr(), false);
+ network->set_input_data("query", query_mem);
+ network->set_input_data("key", key_mem);
+ network->set_input_data("value", value_mem);
+
+ auto outputs = network->execute();
+
+ auto output_data_mem = outputs.at("output_data").get_memory();
+ auto output_scores_mem = outputs.at("scores_data").get_memory();
+
+ return { get_output_data_vec(output_data_mem, num_queries, head_size, num_heads),
+ get_output_scores_vec(output_scores_mem, num_queries, num_keys, num_heads) };
+ }
+
+    std::vector<ov::float16> get_output_scores_vec(memory::ptr scores_output,
+ int num_queries,
+ int num_keys,
+ int num_heads) {
+        OPENVINO_ASSERT(scores_output->count() == static_cast<size_t>(num_heads * num_queries * num_keys));
+
+ std::vector output_scores(num_keys, 0);
+ mem_lock mem_ptr(scores_output, test_stream);
+ for (int head_idx = 0; head_idx < num_heads; head_idx++) {
+ for (int score_idx = 0; score_idx < num_keys; score_idx++) {
+ output_scores[score_idx] += mem_ptr[head_idx * num_queries * num_keys +
+ (num_queries - 1) * num_keys +
+ score_idx];
+ }
+ }
+
+ return output_scores;
+ }
+
+    // Copies the attention data output into a plain std::vector after
+    // validating that the element count matches queries * heads * head_size.
+    std::vector get_output_data_vec(memory::ptr data_output,
+                                    int num_queries,
+                                    int head_size,
+                                    int num_heads) {
+        OPENVINO_ASSERT(data_output->count() == static_cast(num_queries * num_heads * head_size));
+
+        std::vector output_data(data_output->count());
+        mem_lock mem_ptr(data_output, test_stream);
+        for (size_t idx = 0; idx < output_data.size(); ++idx) {
+            output_data[idx] = mem_ptr[idx];
+        }
+
+        return output_data;
+    }
+
+    // Builds a causal attention mask of shape [1, 1, num_queries, num_keys]:
+    // positions a query row may attend to are 0, future positions get the
+    // lowest representable value so softmax suppresses them.
+    memory::ptr get_mask_mem(int num_queries, int num_keys, int num_heads) {
+        /*
+         * Two kinds of masks:
+         *
+         * Case 1 (N == K):
+         * num_queries = N
+         * num_keys = K = N
+         * head_size = H
+         * Q [N, H] * K[H, N]
+         * QK [N, N]
+         *      0    1    N
+         * 0 [  0, MIN, ..,  MIN ]
+         * 1 [  0,   0, ..,  MIN ]
+         *   [ .., ..,  ..,  MIN ]
+         * N [  0,   0, ..,    0 ]
+         *
+         * Case 2 (N != K):
+         * num_queries = N
+         * num_keys = K
+         * head_size = H
+         * past_len = P = K - N + 1
+         * Q [N, H] * K[H, K]
+         * QK [N, K]
+         *      0    1    2    P   ..    K
+         * 0 [  0,   0,   0, MIN, MIN, MIN ]
+         * 1 [  0,   0,   0,   0, MIN, MIN ]
+         *   [ ..,  ..,  ..,  ..,  .., MIN ]
+         * N [  0,   0,   0,   0,  ..,   0 ]
+         *
+         * Shapes:
+         * Q   [1, num_heads, num_queries, head_size]
+         * K   [1, num_heads, head_size, num_keys]
+         * Q*K [1, num_heads, num_queries, num_keys]
+         */
+
+        auto mask_shape = ov::PartialShape{ 1, 1, num_queries, num_keys };
+        auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx};
+        auto mask_mem = test_engine.allocate_memory(mask_layout);
+
+        int past_len = num_keys - num_queries + 1;
+        mem_lock mem_ptr(mask_mem, test_stream);
+        for (int row = 0; row < num_queries; row++) {
+            // first column this query row is NOT allowed to see
+            const int first_masked_col = past_len + row;
+            for (int col = 0; col < num_keys; col++) {
+                if (col < first_masked_col) {
+                    mem_ptr[row * num_keys + col] = ov::float16(0.f);
+                } else {
+                    mem_ptr[row * num_keys + col] = std::numeric_limits::lowest();
+                }
+            }
+        }
+
+        return mask_mem;
+    }
+
+
+ PagedAttentionManager& pam;
+ cldnn::engine& test_engine;
+ cldnn::stream& test_stream;
+};
+
+// Parameterized fixture for the paged_attention primitive: builds the full
+// set of PA inputs via PagedAttentionManager, runs the primitive with fully
+// dynamic input shapes, and compares both outputs against the
+// gemm+softmax-based PagedAttentionReference network.
+template
+struct PagedAttentionTest : public ::testing::TestWithParam {
+public:
+    random_generator rg;
+    cldnn::engine& engine = get_test_engine();
+    float tolerance = 2e-3;
+
+    void SetUp() override {
+        // deterministic per-suite seed so failures are reproducible
+        rg.set_seed(GET_SUITE_NAME);
+    }
+
+    void execute(T& p) {
+        // generates query/key/value data plus the paged KV-cache bookkeeping inputs
+        PagedAttentionManager pam(rg, get_test_engine(), get_test_stream(), p.subsequences, p.num_heads, p.head_size, p.block_size);
+
+        auto query_mem = pam.get_query_memory();
+        auto key_mem = pam.get_key_memory();
+        auto value_mem = pam.get_value_memory();
+
+        auto key_cache_mem = pam.get_key_cache_memory();
+        auto value_cache_mem = pam.get_value_cache_memory();
+
+        auto past_lens_mem = pam.get_past_lens_memory();
+        auto subsequence_begins_mem = pam.get_subsequence_begins_memory();
+        auto block_indices_mem = pam.get_block_indices_memory();
+        auto block_indices_begins_mem = pam.get_block_indices_begins_memory();
+
+        auto scale_mem = pam.get_scale_memory();
+        auto sliding_window_mem = pam.get_sliding_window_memory();
+        auto alibi_mem = pam.get_alibi_memory();
+        auto max_context_len_mem = pam.get_max_context_len_memory();
+
+        auto query_layout = query_mem->get_layout();
+        auto key_layout = key_mem->get_layout();
+        auto value_layout = value_mem->get_layout();
+        auto key_cache_layout = key_cache_mem->get_layout();
+        auto value_cache_layout = value_cache_mem->get_layout();
+        auto past_lens_layout = past_lens_mem->get_layout();
+        auto subsequence_begins_layout = subsequence_begins_mem->get_layout();
+        auto block_indices_layout = block_indices_mem->get_layout();
+        auto block_indices_begins_layout = block_indices_begins_mem->get_layout();
+        auto scale_layout = scale_mem->get_layout();
+        auto sliding_window_layout = sliding_window_mem->get_layout();
+        auto alibi_layout = alibi_mem->get_layout();
+        auto max_context_len_layout = max_context_len_mem->get_layout();
+
+        // make layouts dynamic (token dimension unknown at graph-build time)
+        query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.head_size });
+        key_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.head_size });
+        value_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.head_size });
+        key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.head_size, p.block_size });
+        value_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.head_size });
+        past_lens_layout.set_partial_shape(ov::PartialShape{ -1 });
+        subsequence_begins_layout.set_partial_shape(ov::PartialShape{ -1 });
+        block_indices_layout.set_partial_shape(ov::PartialShape{ -1 });
+        block_indices_begins_layout.set_partial_shape(ov::PartialShape{ -1 });
+
+        // paged_attention consumes the 13 inputs above in exactly this order
+        auto pa_prim = paged_attention("paged_attention", { input_info("query"),
+                                                            input_info("key"),
+                                                            input_info("value"),
+                                                            input_info("key_cache"),
+                                                            input_info("value_cache"),
+                                                            input_info("past_lens"),
+                                                            input_info("subsequence_begins"),
+                                                            input_info("block_indices"),
+                                                            input_info("block_indices_begins"),
+                                                            input_info("scale"),
+                                                            input_info("sliding_window"),
+                                                            input_info("alibi"),
+                                                            input_info("max_context_len") });
+
+        // scalar attributes; requesting scores adds a second primitive output
+        pa_prim.head_size = p.head_size;
+        pa_prim.kv_heads_num = p.num_heads;
+        pa_prim.heads_num = p.num_heads;
+        pa_prim.scale_val = pam.get_default_scale();
+        pa_prim.has_alibi = false;
+        pa_prim.num_outputs = p.scores_output ? 2 : 1;
+
+        topology topology;
+        topology.add(
+            input_layout("query", query_layout),
+            input_layout("key", key_layout),
+            input_layout("value", value_layout),
+            input_layout("key_cache", key_cache_layout),
+            input_layout("value_cache", value_cache_layout),
+            input_layout("past_lens", past_lens_layout),
+            input_layout("subsequence_begins", subsequence_begins_layout),
+            input_layout("block_indices", block_indices_layout),
+            input_layout("block_indices_begins", block_indices_begins_layout),
+            input_layout("scale", scale_layout),
+            input_layout("sliding_window", sliding_window_layout),
+            input_layout("alibi", alibi_layout),
+            input_layout("max_context_len", max_context_len_layout),
+            pa_prim,
+            reorder("output_data", input_info("paged_attention", 0), format::bfyx, data_types::f16)
+        );
+
+        // optional scores output (primitive output index 1)
+        if (p.scores_output) {
+            topology.add(reorder("output_scores", input_info("paged_attention", 1), format::bfyx, data_types::f16));
+        }
+
+        ExecutionConfig config = get_test_default_config(get_test_engine());
+        config.set_property(ov::intel_gpu::optimize_data(true));
+        config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+
+        network::ptr network = get_network(get_test_engine(), topology, config, get_test_stream_ptr(), false);
+        network->set_input_data("query", query_mem);
+        network->set_input_data("key", key_mem);
+        network->set_input_data("value", value_mem);
+        network->set_input_data("key_cache", key_cache_mem);
+        network->set_input_data("value_cache", value_cache_mem);
+        network->set_input_data("past_lens", past_lens_mem);
+        network->set_input_data("subsequence_begins", subsequence_begins_mem);
+        network->set_input_data("block_indices", block_indices_mem);
+        network->set_input_data("block_indices_begins", block_indices_begins_mem);
+        network->set_input_data("scale", scale_mem);
+        network->set_input_data("sliding_window", sliding_window_mem);
+        network->set_input_data("alibi", alibi_mem);
+        network->set_input_data("max_context_len", max_context_len_mem);
+
+        auto outputs = network->execute();
+
+        cldnn::memory::ptr output_data_mem = nullptr;
+        cldnn::memory::ptr output_scores_mem = nullptr;
+
+        output_data_mem = outputs.at("output_data").get_memory();
+        if (p.scores_output) {
+            output_scores_mem = outputs.at("output_scores").get_memory();
+        }
+
+        // reference is built from the same state captured by the manager
+        auto ref_data = PagedAttentionReference(pam).get_reference();
+        compare(output_data_mem, output_scores_mem, ref_data);
+    }
+
+    // Element-wise comparison of actual vs reference outputs; either memory
+    // may be null (scores memory is null when the test does not request the
+    // scores output), in which case that comparison is skipped.
+    void compare(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) {
+        if (data_output_mem) {
+            ASSERT_EQ(data_output_mem->count(), ref_data.first.size());
+            mem_lock mem_ptr(data_output_mem, get_test_stream());
+            for (size_t i = 0; i < data_output_mem->count(); i++) {
+                ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance);
+            }
+        }
+
+        if (scores_output_mem) {
+            ASSERT_EQ(scores_output_mem->count(), ref_data.second.size());
+            mem_lock mem_ptr(scores_output_mem, get_test_stream());
+            for (size_t i = 0; i < scores_output_mem->count(); i++) {
+                ASSERT_NEAR(mem_ptr[i], ref_data.second[i], tolerance);
+            }
+        }
+    }
+};
+
+// Description of a single paged_attention test case.
+struct paged_attention_test_params {
+    std::vector subsequences;  // one entry per subsequence — presumably {new tokens, past length} pairs; TODO(review): confirm against the descriptor type
+    int num_heads;             // number of attention heads (also used as kv_heads_num)
+    int head_size;             // per-head hidden dimension
+    int block_size;            // KV-cache block size
+    bool scores_output;        // when true, the primitive produces a second (scores) output
+};
+
+class paged_attention_test : public PagedAttentionTest {};
+// Runs a single paged_attention network for the parameterized configuration
+// and compares data (and, when requested, scores) against the reference.
+TEST_P(paged_attention_test, basic) {
+    auto p = GetParam();
+
+    execute(p);
+}
+
+// Each subsequence entry is a pair; judging from the per-case comments below
+// it reads as {new tokens, past length} — TODO(review): confirm against the
+// subsequence descriptor type. Remaining fields: num_heads, head_size,
+// block_size, scores_output.
+INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{
+    /* with scores output */
+    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, true },                  // 1st token
+    paged_attention_test_params{ {{36, 0}}, 2, 64, 16, true },                  // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, true },                // 1st token long
+    paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, true },         // 1st token + 1st token
+    paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, true },       // 1st token + 1st token
+    paged_attention_test_params{ {{1, 10}}, 2, 64, 16, true },                  // 2nd token
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, true },        // 2nd token + 2nd token
+    paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, true }, // mixed: 2nd token + 1st token + part of 1st token
+    /* without scores output */
+    paged_attention_test_params{ {{10, 0}}, 2, 64, 16, false },                 // 1st token
+    paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, false },               // 1st token long
+    paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, false },       // 2nd token + 2nd token
+}));