From e2ac5352376b32fe7d949b43f362b27c3c56f5c4 Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Fri, 27 Dec 2024 19:04:59 +0400 Subject: [PATCH] [GPU] Add scores output support for PagedAttention (#28205) ### Details: - Added scores output support for PagedAttention - Added PagedAttention unit tests ### Tickets: - [CVS-153660](https://jira.devtools.intel.com/browse/CVS-153660) --- .../intel_gpu/primitives/paged_attention.hpp | 4 + .../src/graph/impls/ocl/paged_attention.cpp | 279 +++++-- .../src/graph/include/paged_attention_inst.h | 12 +- .../intel_gpu/src/graph/paged_attention.cpp | 87 ++- .../kernel_selector/cl_kernels/pa_sdpa_opt.cl | 182 +++++ .../kernel_selector/cl_kernels/sdpa_opt.cl | 41 ++ .../sdpa/pa_kv_cache_update_kernel_ref.cpp | 2 +- .../kernels/sdpa/pa_sdpa_kernel_opt.cpp | 180 ++++- .../kernels/sdpa/pa_sdpa_kernel_opt.h | 10 +- .../kernels/sdpa/sdpa_kernel_base.h | 1 + .../kernels/sdpa/sdpa_kernel_opt.cpp | 96 ++- .../kernels/sdpa/sdpa_kernel_opt.h | 3 + .../src/plugin/ops/paged_attention.cpp | 5 +- .../test_cases/paged_attention_gpu_test.cpp | 687 ++++++++++++++++++ 14 files changed, 1428 insertions(+), 161 deletions(-) create mode 100644 src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index f87f608597a6bb..2638f2ad60cf26 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -24,6 +24,10 @@ struct paged_attention : public primitive_base { OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size()); } + bool has_scores_output() const { + return num_outputs == 2; + } + bool operator==(const primitive& rhs) const override { return compare_common_params(rhs); } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp index 9cf1a252564934..2bc377f2c1459a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp @@ -63,6 +63,7 @@ struct paged_attention_impl : multi_stage_primitive { void load(BinaryInputBuffer& ib) override { parent::load(ib); + ib >> make_data(&has_scores_output, sizeof(bool)); if (is_dynamic()) { auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance(); auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_UPDATE].kernelName); @@ -78,7 +79,45 @@ struct paged_attention_impl : multi_stage_primitive { } } + void save(BinaryOutputBuffer& ob) const override { + parent::save(ob); + ob << make_data(&has_scores_output, sizeof(bool)); + } + std::vector get_internal_buffer_layouts_impl() const override { + /* + * Internal buffers allocation owners and users: + * +--------------------------------------+--------------------+--------------------+ + * | Stage | Allocates & uses | Reuses | + * +--------------------------------------+--------------------+--------------------+ + * | KV_CACHE_UPDATE | [0, 1, 2] | | + * +--------------------------------------+--------------------+--------------------+ + * | SDPA (1st token) | | [0, 1, 2] | + * +--------------------------------------+--------------------+--------------------+ + * | PA_SDPA (2nd+ token) | [5, 6, 7] | | + * +--------------------------------------+--------------------+--------------------+ + * | PA_SDPA (mixed mode) | [5, 6, 7, 8] | | + * +--------------------------------------+--------------------+--------------------+ + * | SDPA (1st token) + scores output | | [0, 1, 2, 3, 4] | + * +--------------------------------------+--------------------+--------------------+ + * | PA_SDPA (2nd+ token) + scores output | [3, 4, 5, 6, 7] | | + * +--------------------------------------+--------------------+--------------------+ + * | PA_SDPA (mixed mode) + scores output | [3, 4, 5, 6, 7, 8] | | + * +--------------------------------------+--------------------+--------------------+ + * + * Description: + * 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and + * sdpa_opt (1st token calculation) block configuration over target_seq_len dimension. + * Filled in paged_attention_inst::on_execute() call. + * 3, 4 - Optional buffers used for PA scores output calculation, storing intermediate + * softmax values by partitions (filled in PA/SDPA kernels) and sequence length offsets + * for each subsequence (filled in paged_attention_inst::on_execute() call). + * 5, 6, 7 - Used for 2nd+ PA calculation (for softmax exp_sums, max_logits, and intermediate output). + * Filled in PA/SDPA kernels. + * 8 - Optional buffer used for mixed PA execution mode, mapping gws idx to subsequence id. + * Filled in paged_attention_inst::on_execute() call. + */ + auto add_internal_buffers = [](std::vector& layouts, const kernel_selector::KernelData& kd) { if (kd.internalBufferSizes.empty()) return; @@ -133,6 +172,7 @@ struct paged_attention_impl : multi_stage_primitive { args.outputs = { instance.output_memory_ptr(0) }; } else if (stage == Stage::PA_SDPA) { if (kernel_idx == 0 || kernel_idx == 1) { + // 2nd+ token calculation or mixed stage tokens calculation args.shape_info = instance.shape_info_memory_ptr(); args.inputs = { instance.input_memory_ptr(0), @@ -155,7 +195,8 @@ struct paged_attention_impl : multi_stage_primitive { if (desc->has_alibi) { args.inputs.push_back(instance.alibi_memory_ptr()); } - } else { + } else if (kernel_idx == 2 || kernel_idx == 3) { + // Finalization kernel or mixed stage finalization kernel args.inputs = { instance.past_lens_memory_ptr() }; if (is_mixed_mode) { @@ -163,17 +204,31 @@ struct paged_attention_impl : multi_stage_primitive { // dependency args.inputs.push_back(instance.subsequence_begins_memory_ptr()); } + } else if (kernel_idx == 4) { + // Output scores calculation kernel + args.inputs = { instance.past_lens_memory_ptr(), + instance.subsequence_begins_memory_ptr() }; } args.outputs = { instance.output_memory_ptr(0) }; + + if (kernel_idx == 4) { + args.outputs.push_back(instance.output_memory_ptr(1)); + } } return args; } std::set get_lockable_internal_buffers() const override { - return std::set{ 0, 1, 2, /* SDPA and KV_CACHE_UPDATE indexes configuration */ - 6, /* PA_SDPA multiple tokens mode */ }; + size_t mixed_mode_buffer = has_scores_output ? 8 : 6; + + std::set lockable_ids = { 0, 1, 2, /* SDPA and KV_CACHE_UPDATE indexes configuration */ + mixed_mode_buffer /* PA_SDPA multiple tokens mode */ }; + if (has_scores_output) + lockable_ids.insert(4 /* Precalculated accumulated sequence length offsets for each subsequence */); + + return lockable_ids; }; void execute_stage(const std::vector& events, @@ -194,8 +249,17 @@ struct paged_attention_impl : multi_stage_primitive { if (stage == Stage::PA_SDPA) { internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size(); internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size(); - } else { + } else if (stage == Stage::KV_CACHE_UPDATE) { + internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size(); + } else if (stage == Stage::SDPA) { internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size(); + + const auto desc = instance.get_node().as().get_primitive(); + if (desc->has_scores_output()) { + // Add intermediate buffers for PagedAttention scores calculation: + // softmax_results, subsequence_offsets, exp_sums, max_logits, tmp_out + internal_buffers_count += 5; + } } for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) { @@ -216,6 +280,23 @@ struct paged_attention_impl : multi_stage_primitive { intermediate_memories.begin() + internal_buffers_offset, intermediate_memories.begin() + internal_buffers_offset + internal_buffers_count); + GPU_DEBUG_TRACE_DETAIL << "Execute stage=" << stage << " kernel=" << kd_idx << " " << _kernels_data[stage].kernelName << " start_offset=" + << internal_buffers_offset << " count=" << internal_buffers_count << "\n"; + + GPU_DEBUG_TRACE_DETAIL << "Configured kernel arguments:\n"; + for (size_t i = 0; i < _kernels_data[stage].kernels[kd_idx].params.arguments.size(); i++) { + GPU_DEBUG_TRACE_DETAIL << "\t" << i << ": type=" << static_cast(_kernels_data[stage].kernels[kd_idx].params.arguments[i].t) << " " + << "index=" << _kernels_data[stage].kernels[kd_idx].params.arguments[i].index << "\n"; + } + + GPU_DEBUG_TRACE_DETAIL << "Memory buffers:" + << "shape_info=" << args.shape_info << " " + << "inputs=" << args.inputs.size() << " " + << "outputs=" << args.outputs.size() << " " + << "intermediates=" << args.intermediates.size() << " " + << "weights=" << args.weights << " " + << "scalars=" << (args.scalars ? args.scalars->size() : 0) << "\n"; + stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args); const auto& gws = params.workGroups.global; @@ -242,10 +323,13 @@ struct paged_attention_impl : multi_stage_primitive { execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode); - std::vector dep_events(res_events.begin(), res_events.end()); if (stage == PagedAttentionStage::PREFILL) { + std::vector dep_events(res_events.begin(), res_events.end()); execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode); - } else if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) { + } + + if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || has_scores_output) { + std::vector dep_events(res_events.begin(), res_events.end()); execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode); } @@ -338,7 +422,7 @@ struct paged_attention_impl : multi_stage_primitive { return aligned_seq_len; } - static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) { + static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param, bool is_dynamic = true) { kernel_selector::sdpa_configuration config; const auto desc = impl_param.typed_desc(); @@ -362,37 +446,45 @@ struct paged_attention_impl : multi_stage_primitive { config.group_size = desc->heads_num / desc->kv_heads_num; } + if (desc->has_scores_output() && !is_dynamic) { + const auto& input_mem = impl_param.memory_deps; + const auto max_context_len = input_mem.at(12); + mem_lock max_context_len_mem_lock(max_context_len, *impl_param.strm); + config.paged_attention_max_len = max_context_len_mem_lock[0]; + } + return config; } static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, + const kernel_selector::MultiDataTensor& input_tensors, bool is_dynamic = false) { auto params = get_default_params(impl_param, is_dynamic); - const auto& key_layout = impl_param.get_input_layout(1); - const auto& value_layout = impl_param.get_input_layout(2); - const auto& key_cache_layout = impl_param.get_input_layout(3); - const auto& value_cache_layout = impl_param.get_input_layout(4); - const auto& past_lens_layout = impl_param.get_input_layout(5); - const auto& block_indices_layout = impl_param.get_input_layout(7); - const auto& block_indices_begins_layout = impl_param.get_input_layout(8); - const auto& subsequence_begins_layout = impl_param.get_input_layout(6); + const auto& key_tensor = input_tensors[1]; + const auto& value_tensor = input_tensors[2]; + const auto& key_cache_tensor = input_tensors[3]; + const auto& value_cache_tensor = input_tensors[4]; + const auto& past_lens_tensor = input_tensors[5]; + const auto& block_indices_tensor = input_tensors[7]; + const auto& block_indices_begins_tensor = input_tensors[8]; + const auto& subsequence_begins_tensor = input_tensors[6]; const auto inputs_number = 6; const auto outputs_number = 2; params.inputs.resize(inputs_number); params.outputs.resize(outputs_number); - params.inputs[0] = convert_data_tensor(key_layout); - params.inputs[1] = convert_data_tensor(value_layout); - params.inputs[2] = convert_data_tensor(past_lens_layout); - params.inputs[3] = convert_data_tensor(block_indices_layout); - params.inputs[4] = convert_data_tensor(block_indices_begins_layout); - params.inputs[5] = convert_data_tensor(subsequence_begins_layout); - params.outputs[0] = convert_data_tensor(key_cache_layout); - params.outputs[1] = convert_data_tensor(value_cache_layout); + params.inputs[0] = key_tensor; + params.inputs[1] = value_tensor; + params.inputs[2] = past_lens_tensor; + params.inputs[3] = block_indices_tensor; + params.inputs[4] = block_indices_begins_tensor; + params.inputs[5] = subsequence_begins_tensor; + params.outputs[0] = key_cache_tensor; + params.outputs[1] = value_cache_tensor; - params.conf = get_sdpa_configuration(impl_param); + params.conf = get_sdpa_configuration(impl_param, is_dynamic); params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED; @@ -418,18 +510,23 @@ struct paged_attention_impl : multi_stage_primitive { return params; } - static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) { + static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, + const PagedAttentionStage& stage, + const kernel_selector::MultiDataTensor& input_tensors, + bool is_dynamic = false) { const auto desc = impl_param.typed_desc(); auto params = get_default_params(impl_param, is_dynamic); - const auto& query_layout = impl_param.get_input_layout(0); - const auto& key_layout = impl_param.get_input_layout(1); - const auto& value_layout = impl_param.get_input_layout(2); - const auto& subsequence_begins_layout = impl_param.get_input_layout(6); - const auto& scale_layout = impl_param.get_input_layout(9); - const auto& alibi_layout = impl_param.get_input_layout(11); - const auto has_alibi = alibi_layout.count() > 0; + const auto& query_tensor = input_tensors[0]; + const auto& key_tensor = input_tensors[1]; + const auto& value_tensor = input_tensors[2]; + const auto& subsequence_begins_tensor = input_tensors[6]; + const auto& scale_tensor = input_tensors[9]; + const auto& alibi_tensor = input_tensors[11]; + + const auto has_alibi = impl_param.get_input_layout(11).count() > 0; const auto has_scale_input = !desc->scale_val.has_value(); + const auto has_scores_output = desc->has_scores_output(); auto inputs_number = 4; if (has_scale_input) @@ -440,18 +537,23 @@ struct paged_attention_impl : multi_stage_primitive { auto input_idx = 0; params.inputs.resize(inputs_number); - params.inputs[input_idx++] = convert_data_tensor(query_layout); - params.inputs[input_idx++] = convert_data_tensor(key_layout); - params.inputs[input_idx++] = convert_data_tensor(value_layout); - params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout); + params.inputs[input_idx++] = query_tensor; + params.inputs[input_idx++] = key_tensor; + params.inputs[input_idx++] = value_tensor; + params.inputs[input_idx++] = subsequence_begins_tensor; if (has_scale_input) - params.inputs[input_idx++] = convert_data_tensor(scale_layout); + params.inputs[input_idx++] = scale_tensor; if (has_alibi) - params.inputs[input_idx++] = convert_data_tensor(alibi_layout); + params.inputs[input_idx++] = alibi_tensor; - params.conf = get_sdpa_configuration(impl_param); + if (has_scores_output) { + params.outputs.resize(2); + params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1)); + } + + params.conf = get_sdpa_configuration(impl_param, is_dynamic); const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; @@ -475,26 +577,34 @@ struct paged_attention_impl : multi_stage_primitive { if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic) params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage); + if (has_scores_output) + out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)}); + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); return params; } - static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) { + static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, + const PagedAttentionStage& stage, + const kernel_selector::MultiDataTensor& input_tensors, + bool is_dynamic = false) { const auto desc = impl_param.typed_desc(); auto params = get_default_params(impl_param, is_dynamic); - const auto& query_layout = impl_param.get_input_layout(0); - const auto& key_cache_layout = impl_param.get_input_layout(3); - const auto& value_cache_layout = impl_param.get_input_layout(4); - const auto& past_lens_layout = impl_param.get_input_layout(5); - const auto& block_indices_layout = impl_param.get_input_layout(7); - const auto& block_indices_begins_layout = impl_param.get_input_layout(8); - const auto& subsequence_begins_layout = impl_param.get_input_layout(6); - const auto& scale_layout = impl_param.get_input_layout(9); - const auto& alibi_layout = impl_param.get_input_layout(11); - const auto has_alibi = alibi_layout.count() > 0; + const auto& query_tensor = input_tensors[0]; + const auto& key_cache_tensor = input_tensors[3]; + const auto& value_cache_tensor = input_tensors[4]; + const auto& past_lens_tensor = input_tensors[5]; + const auto& block_indices_tensor = input_tensors[7]; + const auto& block_indices_begins_tensor = input_tensors[8]; + const auto& subsequence_begins_tensor = input_tensors[6]; + const auto& scale_tensor = input_tensors[9]; + const auto& alibi_tensor = input_tensors[11]; + + const auto has_alibi = impl_param.get_input_layout(11).count() > 0; const auto has_scale_input = !desc->scale_val.has_value(); + const auto has_scores_output = desc->has_scores_output(); auto inputs_number = 7; if (has_scale_input) @@ -505,28 +615,34 @@ struct paged_attention_impl : multi_stage_primitive { auto input_idx = 0; params.inputs.resize(inputs_number); - params.inputs[input_idx++] = convert_data_tensor(query_layout); - params.inputs[input_idx++] = convert_data_tensor(key_cache_layout); - params.inputs[input_idx++] = convert_data_tensor(value_cache_layout); - params.inputs[input_idx++] = convert_data_tensor(past_lens_layout); - params.inputs[input_idx++] = convert_data_tensor(block_indices_layout); - params.inputs[input_idx++] = convert_data_tensor(block_indices_begins_layout); - params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout); - params.conf = get_sdpa_configuration(impl_param); + params.inputs[input_idx++] = query_tensor; + params.inputs[input_idx++] = key_cache_tensor; + params.inputs[input_idx++] = value_cache_tensor; + params.inputs[input_idx++] = past_lens_tensor; + params.inputs[input_idx++] = block_indices_tensor; + params.inputs[input_idx++] = block_indices_begins_tensor; + params.inputs[input_idx++] = subsequence_begins_tensor; + + params.conf = get_sdpa_configuration(impl_param, is_dynamic); if (has_scale_input) - params.inputs[input_idx++] = convert_data_tensor(scale_layout); + params.inputs[input_idx++] = scale_tensor; if (has_alibi) - params.inputs[input_idx++] = convert_data_tensor(alibi_layout); + params.inputs[input_idx++] = alibi_tensor; - params.multi_tokens_mode = stage == PagedAttentionStage::MIXED; + if (has_scores_output) { + params.outputs.resize(2); + params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1)); + } - if ((stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) && !is_dynamic) { + params.stage = stage; + + if (!has_scores_output && !is_dynamic) { const auto& input_mem = impl_param.memory_deps; const auto max_context_len = input_mem.at(12); mem_lock max_context_len_mem_lock(max_context_len, *impl_param.strm); - params.max_context_len = max_context_len_mem_lock[0]; + params.conf.paged_attention_max_len = max_context_len_mem_lock[0]; } const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; @@ -552,6 +668,9 @@ struct paged_attention_impl : multi_stage_primitive { if (has_alibi) in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)}); + if (has_scores_output) + out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)}); + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); return params; @@ -560,14 +679,20 @@ struct paged_attention_impl : multi_stage_primitive { void update_dispatch_data(const kernel_impl_params& impl_param) override { const auto stage = get_paged_attention_stage(impl_param); - auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, impl_param.is_dynamic()); + kernel_selector::MultiDataTensor input_tensors; + for (const auto& input_layout : impl_param.input_layouts) + input_tensors.emplace_back(convert_data_tensor(input_layout)); + + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]); if (stage == PagedAttentionStage::PREFILL) { - auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, impl_param.is_dynamic()); + auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]); - } else if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) { - auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, impl_param.is_dynamic()); + } + + if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || has_scores_output) { + auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::PA_SDPA].update_dispatch_data_func)(pa_sdpa_kernel_params, _kernels_data[Stage::PA_SDPA]); } } @@ -576,20 +701,32 @@ struct paged_attention_impl : multi_stage_primitive { std::vector kernels_data; const auto stage = PagedAttentionStage::UNKNOWN; - auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, impl_param.is_dynamic()); + kernel_selector::MultiDataTensor input_tensors; + for (const auto& input_layout : impl_param.input_layouts) + input_tensors.emplace_back(convert_data_tensor(input_layout)); + + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance(); kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params)); - auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, impl_param.is_dynamic()); + auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance(); kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params)); - auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, impl_param.is_dynamic()); + auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance(); kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params)); - return cldnn::make_unique(kernels_data); + auto impl = cldnn::make_unique(kernels_data); + + const auto& desc = impl_param.typed_desc(); + impl->has_scores_output = desc->has_scores_output(); + + return impl; } + +private: + bool has_scores_output = false; }; namespace detail { diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index a7918ba9c3719c..675d77296aa06b 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -7,14 +7,11 @@ #include "intel_gpu/primitives/paged_attention.hpp" #include "primitive_inst.h" +#include "sdpa/pa_sdpa_kernel_opt.h" + namespace cldnn { -enum PagedAttentionStage { - GENERATE = 0, - PREFILL = 1, - MIXED = 2, - UNKNOWN = 3 -}; +using PagedAttentionStage = kernel_selector::PagedAttentionStage; PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param); @@ -61,6 +58,9 @@ class typed_primitive_inst : public typed_primitive_inst_base

prefill_network; diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 787fd184f75b6a..c761aaf63799cd 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -48,14 +48,38 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no template std::vector paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) { - auto out_layout = impl_param.get_input_layout(0); + auto data_layout = impl_param.get_input_layout(0); const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape(); bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size; OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. " "Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length()); - return {out_layout}; + std::vector output_layouts{ data_layout }; + + const auto& desc = impl_param.typed_desc(); + if (desc->has_scores_output()) { + const auto past_lens_idx = 5; + const auto output_dt = data_layout.data_type; + if (impl_param.get_input_layout(past_lens_idx).is_static()) { + const auto& memory_deps = impl_param.memory_deps; + const auto past_lens_mem = memory_deps.at(past_lens_idx); + mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm); + + long int total_size = 0; + for (size_t i = 0; i < past_lens_mem_lock.size(); i++) { + total_size += past_lens_mem_lock[i]; + } + + total_size += static_cast(impl_param.get_input_layout(0).get_shape()[0]); + + output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx}); + } else { + output_layouts.push_back(layout{ov::PartialShape::dynamic(1), output_dt, format::bfyx}); + } + } + + return output_layouts; } template std::vector @@ -81,45 +105,79 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) { } void paged_attention_inst::on_execute() { - auto stage = get_paged_attention_stage(*_impl_params); + const auto& desc = _impl_params->typed_desc(); + const bool has_scores_output = desc->has_scores_output(); + const auto stage = get_paged_attention_stage(*_impl_params); - if (stage == PagedAttentionStage::UNKNOWN || - stage == PagedAttentionStage::GENERATE) + if ((stage == PagedAttentionStage::UNKNOWN) || + (stage == PagedAttentionStage::GENERATE && !has_scores_output)) return; + auto& stream = get_network().get_stream(); + const auto past_lens_mem = past_lens_memory_ptr(); + const auto subsequence_begins_mem = subsequence_begins_memory_ptr(); + mem_lock past_lens_mem_lock(past_lens_mem, stream); + mem_lock subsequence_begins_mem_lock(subsequence_begins_mem, stream); + std::unique_ptr> subsequence_offsets_lock = nullptr; + + if (has_scores_output) { + const size_t subsequence_offsets_idx = 4; + + OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx, + "[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation"); + + auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx]; + subsequence_offsets_lock.reset(new mem_lock(subsequence_offsets_mem, stream)); + } + + if (stage == PagedAttentionStage::GENERATE) { + // For the generate stage it's not necessary to configure any other intermediate + // buffers. Simply calculate the offsets and exit + size_t subsequence_offsets_acc = 0; + for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { + const auto past_len = past_lens_mem_lock[i]; + const auto seq_start = subsequence_begins_mem_lock[i]; + const auto seq_end = subsequence_begins_mem_lock[i + 1]; + const auto seq_length = seq_end - seq_start; + + if (subsequence_offsets_lock) { + subsequence_offsets_lock->operator[](i) = static_cast(subsequence_offsets_acc); + subsequence_offsets_acc += seq_length + past_len; + } + } + + return; + } + OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage"); const auto blocks_indexes_start_idx = 0; const auto blocks_indexes_end_idx = 1; const auto blocked_gws_subseq_mapping_idx = 2; - const auto past_lens_mem = past_lens_memory_ptr(); - auto subsequence_begins_mem = subsequence_begins_memory_ptr(); auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx]; auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx]; auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx]; OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32); - auto& stream = get_network().get_stream(); - mem_lock past_lens_mem_lock(past_lens_mem, stream); - mem_lock subsequence_begins_mem_lock(subsequence_begins_mem, stream); mem_lock blocks_indexes_start_lock(blocks_indexes_start_mem, stream); mem_lock blocks_indexes_end_lock(blocks_indexes_end_mem, stream); mem_lock blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream); std::unique_ptr> sequential_gws_subseq_mapping_lock = nullptr; if (stage == PagedAttentionStage::MIXED) { - const auto sequential_gws_subseq_mapping_idx = 6; + const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6; OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx, - "Unexpected number of intermediates buffers for Paged Attention for mixed stage"); + "[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage"); auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx]; sequential_gws_subseq_mapping_lock.reset(new mem_lock(sequential_gws_subseq_mapping_mem, stream)); } size_t index = 0; + size_t subsequence_offsets_acc = 0; const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) { const auto past_len = past_lens_mem_lock[i]; @@ -159,6 +217,11 @@ void paged_attention_inst::on_execute() { sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast(i); } } + + if (subsequence_offsets_lock) { + subsequence_offsets_lock->operator[](i) = static_cast(subsequence_offsets_acc); + subsequence_offsets_acc += seq_length + past_len; + } } } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl index 00c43829d02ea7..7e960afa4b87d3 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl @@ -44,6 +44,10 @@ KERNEL(pa_sdpa_opt)( const __global ALIBI_INPUT_TYPE* alibi_slopes, #endif __global OUTPUT_TYPE* output, +#if PAGED_ATTENTION_SCORES_OUTPUT + __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results, + const __global int* subsequence_offsets, +#endif __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, __global OUTPUT_TYPE* tmp_out @@ -276,6 +280,28 @@ KERNEL(pa_sdpa_opt)( const uint max_logits_offset = exp_sums_offset; max_logits[max_logits_offset] = qk_max; } + +#if PAGED_ATTENTION_SCORES_OUTPUT +#if MULTI_TOKENS_PROCESSING + const uint subsequence_idx = gws_subseq_mapping[seq_idx]; + const uint subsequence_start_pos = subsequence_begins[subsequence_idx]; + const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1]; + const bool save_softmax_results = seq_idx == subsequence_end_pos - 1; +#else + const uint subsequence_idx = seq_idx; + const bool save_softmax_results = true; +#endif // MULTI_TOKENS_PROCESSING + // PagedAttention is supposed to save only last "row" of the QK matrix multiplication, + // so save SEQ_LEN_PARTITION_SIZE elements for each partition + if (save_softmax_results) { + const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE + + head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE + + partition_idx * SEQ_LEN_PARTITION_SIZE; + for (uint i = sgid * SUBGROUP_SIZE + sglid; i < SEQ_LEN_PARTITION_SIZE; i += SUBGROUPS_PER_WG * SUBGROUP_SIZE) { + softmax_results[output_offset + i] = slm_qk_vals[i]; + } + } +#endif // PAGED_ATTENTION_SCORES_OUTPUT } } @@ -370,6 +396,10 @@ KERNEL(pa_sdpa_finalization_stage)( const __global INPUT6_TYPE* subsequence_begins, #endif __global OUTPUT_TYPE* output, +#if PAGED_ATTENTION_SCORES_OUTPUT + __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results, + const __global int* subsequence_offsets, +#endif const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, const __global OUTPUT_TYPE* tmp_out, @@ -500,3 +530,155 @@ KERNEL(pa_sdpa_finalization_stage)( } #endif + +#ifdef SDPA_STAGE_2 +#define MAX_PARTITIONS_NUM 128 + +REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) +KERNEL(pa_sdpa_scores_calculation)( + const __global INPUT3_TYPE* past_lens, + const __global INPUT6_TYPE* subsequence_begins, + __global OUTPUT1_TYPE* scores_output, + const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output, + const __global int* subsequence_offsets, + const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, + const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, + const __global OUTPUT_TYPE* tmp_out, + const uint is_mixed_mode) { + const uint subsequence_idx = get_global_id(2); + const uint partition_global_idx = get_global_id(0); + const uint local_id = get_local_id(0); + const uint partition_idx = get_group_id(0); + const uint partition_size = get_local_size(0); + const uint max_seq_len = get_global_size(0); + const uint partitions_num = get_num_groups(0); + const uint sgid = get_sub_group_id(); + const uint sgid_num = get_num_sub_groups(); + const uint sglid = get_sub_group_local_id(); + + const int subsequence_begin = subsequence_begins[subsequence_idx]; + const int subsequence_end = subsequence_begins[subsequence_idx + 1]; + const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx]; + + const uint num_of_partitions = CEIL_DIV(seq_len, partition_size); + + if (partition_idx >= num_of_partitions) + return; + + __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sums[HEADS_NUM]; + __local SOFTMAX_ACCUMULATOR_TYPE slm_global_exp_sum[HEADS_NUM]; + + SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO; + if (seq_len <= partition_size) { + // If seq_len is less than the partition size, just reduce the results over the heads + for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) { + const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx; + SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset]; + total_score += softmax_value; + } + } else if (seq_len <= partition_size * MAX_PARTITIONS_NUM) { + // Optimized version for longer prompts (up to partition_size * MAX_PARTITIONS_NUM, ~64K tokens) + + // Depending on the previous kernel exp_sums and max_logits might have different structure: + // For ordinary 1st and 2nd token kernels, there is only a single entry per subsequence. + // However, for mixed mode execution, exp_sums and max_logits include information for all + // tokens of each subsequence, but only the last one is needed for score calculation. + const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx; + + for (uint head_idx = sgid; head_idx < HEADS_NUM; head_idx += sgid_num) { + SOFTMAX_ACCUMULATOR_TYPE max_logit[MAX_PARTITIONS_NUM / SUBGROUP_SIZE]; + SOFTMAX_ACCUMULATOR_TYPE exp_sum[MAX_PARTITIONS_NUM / SUBGROUP_SIZE]; + + const uint exp_sums_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num; + for (int i = 0; i < partitions_num / SUBGROUP_SIZE; i++) { + max_logit[i] = max_logits[exp_sums_offset + i * SUBGROUP_SIZE + sglid]; + exp_sum[i] = exp_sums[exp_sums_offset + i * SUBGROUP_SIZE + sglid]; + } + + const uint partitions_leftovers = partitions_num % SUBGROUP_SIZE; + if (partitions_leftovers != 0) { + const uint idx = partitions_num / SUBGROUP_SIZE; + max_logit[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[exp_sums_offset + idx * SUBGROUP_SIZE + sglid]; + exp_sum[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_offset + idx * SUBGROUP_SIZE + sglid]; + } + + SOFTMAX_ACCUMULATOR_TYPE global_max_logit = max_logit[0]; + for (uint i = 1; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) { + global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit[i]); + } + + global_max_logit = sub_group_reduce_max(global_max_logit); + + SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) { + SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum[i] * native_exp(max_logit[i] - global_max_logit); + // slm_exp_sums[head_idx][i * SUBGROUP_SIZE + sglid] = adjusted_exp_sum; + if (i * SUBGROUP_SIZE + sglid == partition_idx) + slm_exp_sums[head_idx] = adjusted_exp_sum; + global_exp_sum += adjusted_exp_sum; + } + + global_exp_sum = sub_group_reduce_add(global_exp_sum); + + slm_global_exp_sum[head_idx] = global_exp_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) { + SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = slm_exp_sums[head_idx]; + SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = slm_global_exp_sum[head_idx]; + + const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx; + SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset]; + + softmax_value = softmax_value * adjusted_exp_sum / global_exp_sum; + total_score += softmax_value; + } + } else { + // Non optimized fallback version + const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx; + for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) { + SOFTMAX_ACCUMULATOR_TYPE global_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; + const uint max_logits_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num; + for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) { + const uint partition_offset = i * SUBGROUP_SIZE + sglid; + SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset]; + global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit); + } + + global_max_logit = sub_group_reduce_max(global_max_logit); + + SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE partition_adjusted_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + const uint exp_sums_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num; + for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) { + const uint partition_offset = i * SUBGROUP_SIZE + sglid; + SOFTMAX_ACCUMULATOR_TYPE exp_sum = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_base_offset + partition_offset]; + SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset]; + SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit); + global_exp_sum += adjusted_exp_sum; + + // Save and broadcast the adjusted exp_sum for the currently being processed partition + if (i == partition_idx / SUBGROUP_SIZE) + partition_adjusted_exp_sum = sub_group_broadcast(adjusted_exp_sum, partition_idx % SUBGROUP_SIZE); + } + + global_exp_sum = sub_group_reduce_add(global_exp_sum); + + const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx; + SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset]; + + softmax_value = softmax_value * partition_adjusted_exp_sum / global_exp_sum; + total_score += softmax_value; + } + } + + const uint output_offset = subsequence_offsets[subsequence_idx]; + if (partition_global_idx < seq_len) { + scores_output[output_offset + partition_global_idx] = total_score; + } +} + +#undef MAX_PARTITIONS_NUM +#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index 55f87e4189d9fe..cddafe62623d9e 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -842,6 +842,14 @@ KERNEL(sdpa_opt)( const __global int* blocked_indexes_start, const __global int* blocked_indexes_end, const __global int* gws_seq_indexes_correspondence +#if PAGED_ATTENTION_SCORES_OUTPUT + , __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results + , const __global int* subsequence_offsets + , __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums + , __global SOFTMAX_ACCUMULATOR_TYPE* max_logits + , __global OUTPUT_TYPE* tmp_out + , const uint aligned_max_context_len +#endif #else __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, __global SOFTMAX_ACCUMULATOR_TYPE* max_logits, @@ -1222,6 +1230,39 @@ KERNEL(sdpa_opt)( slm_qk_vals[sglid * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i]; } +#if PAGED_ATTENTION_SCORES_OUTPUT + const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim]; + const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1]; + const uint block_start_pos = blocked_indexes_start[target_seq_dim]; + const uint block_end_pos = blocked_indexes_end[target_seq_dim]; + + // PagedAttention is supposed to save only last "row" of the QK matrix multiplication, + // so save SEQ_LEN_PARTITION_SIZE elements for each partition + if (subsequence_end_pos == block_end_pos) { + const uint last_row_idx = block_end_pos - block_start_pos - 1; + if (sglid == last_row_idx) { + const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE; + + if (sgid == 0) { + const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE; + const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num + + num_heads_dim * max_partitions_num + + partition_idx; + exp_sums[exp_sums_output_offset] = exp_sum_new; + max_logits[exp_sums_output_offset] = qk_max_new; + } + + const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len + + num_heads_dim * aligned_max_context_len + + partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE; + for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { + softmax_results[output_offset + i] = qk_acc[i]; + } + + } + } +#endif + barrier(CLK_LOCAL_MEM_FENCE); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp index ddfb491f50278a..ce20f49de597ff 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp @@ -167,7 +167,7 @@ void KVCacheUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const { const auto indexes_dt = Datatype::INT32; const auto target_seq_len_block_size = 16; - const auto target_seq_len = prim_params.conf.paged_attention_aligned_seq_len; + const auto target_seq_len = std::max(prim_params.conf.paged_attention_aligned_seq_len, static_cast(1)); const auto indexes_buf_size = CeilDiv(target_seq_len, target_seq_len_block_size) * BytesPerElement(indexes_dt); kd.internalBufferSizes.clear(); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp index 63c5e74160f652..909a40d677f535 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "sdpa_kernel_opt.h" #include "pa_sdpa_kernel_opt.h" #include "kernel_selector_params.h" @@ -15,6 +16,7 @@ enum KernelsTypes { MULTI_TOKENS, FINALIZATION, FINALIZATION_MULTI_TOKENS, + SCORES_CALCULATION, TOTAL_KERNELS_NUM }; @@ -35,6 +37,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) { kernel_name += "_finalization"; } else if (type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { kernel_name += "_finalization_multi_tokens_seq"; + } else if (type == KernelsTypes::SCORES_CALCULATION) { + kernel_name += "_scores_calculation"; } return kernel_name; @@ -46,10 +50,15 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { } const auto& params = static_cast(p); - const std::vector kernels_type = { KernelsTypes::SINGLE_TOKEN, - KernelsTypes::MULTI_TOKENS, - KernelsTypes::FINALIZATION, - KernelsTypes::FINALIZATION_MULTI_TOKENS }; + std::vector kernels_type = { KernelsTypes::SINGLE_TOKEN, + KernelsTypes::MULTI_TOKENS, + KernelsTypes::FINALIZATION, + KernelsTypes::FINALIZATION_MULTI_TOKENS }; + + const auto has_scores_output = params.outputs.size() > 1; + if (has_scores_output) { + kernels_type.push_back(KernelsTypes::SCORES_CALCULATION); + } KernelData kd = KernelData::Default(params, kernels_type.size()); kd.needs_sub_kernels_sync = true; @@ -65,7 +74,8 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { const auto jit = CreateJit(kernel_name, jit_constants, entry_point); - size_t inputs_num = static_cast(params.inputs.size()); + int inputs_num = static_cast(params.inputs.size()); + int outputs_num = 1; if (kernel_type == KernelsTypes::SINGLE_TOKEN) { // SINGLE_TOKEN kernel doesn't use the subsequence_begins input inputs_num -= 1; @@ -75,6 +85,11 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { } else if (kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { // FINALIZATION_MULTI_TOKENS kernel uses past_lens data input and subsequence_begins inputs_num = 2; + } else if (kernel_type == KernelsTypes::SCORES_CALCULATION) { + // SCORES_CALCULATION kernel uses past_lens data input and subsequence_begins + inputs_num = 2; + // Output is configured manually to use the second output memory buffer + outputs_num = 0; } auto& kernel = kd.kernels[kd_kernels_idx++]; @@ -87,19 +102,33 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { {}, false, false, - static_cast(inputs_num), + inputs_num, GetFusedPrimitiveInputsCount(params), - static_cast(params.outputs.size()), + outputs_num, params.is_shape_agnostic); - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + if (kernel_type == KernelsTypes::SCORES_CALCULATION) { + kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 1}); + } + + uint32_t internal_buffers_num = 0; + if (has_scores_output) { + // Intermediate softmax results for scores output calculation and precalculated accumulated + // sequence length offsets for each subsequence + internal_buffers_num += 2; + } + + // Softmax's exp_sums, max_logits and intermediate output + internal_buffers_num += 3; if (kernel_type == KernelsTypes::MULTI_TOKENS || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { // MULTIPLE_TOKENS kernels needs additional information related to mapping // launched kernel instances to subsequence indexes - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); + internal_buffers_num++; + } + + for (uint32_t i = 0; i < internal_buffers_num; i++) { + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, i}); } if (kernel_type == KernelsTypes::FINALIZATION || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { @@ -108,6 +137,15 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { // Remove unused shape_info argument at finalization stage kernel.params.arguments.erase(kernel.params.arguments.begin()); } + + if (kernel_type == KernelsTypes::SCORES_CALCULATION) { + // The scores kernel needs to know if the current execution mode is mixed or ordinary + // to configure proper memory access + kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); + + // Remove unused shape_info argument for scores kernel + kernel.params.arguments.erase(kernel.params.arguments.begin()); + } } return {kd}; @@ -173,7 +211,12 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params& jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size)); } - auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS ? 1 : 0; + auto sdpa_stage = 0; + if (kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS) { + sdpa_stage = 1; + } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) { + sdpa_stage = 2; + } jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1)); if (config.has_const_scale_val) { @@ -190,6 +233,10 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params& jit.Merge(MakeTypeJitConstants(params.inputs[alibi_input_idx].GetDType(), "ALIBI_INPUT")); } + if (params.outputs.size() > 1) { + jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_SCORES_OUTPUT", 1)); + } + if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS) jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1)); @@ -203,18 +250,36 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params& const auto& input = params.inputs[0]; if (!input.is_dynamic()) { - const size_t sequences_number = input.Batch().v; - const size_t num_of_partitions = CeilDiv(params.max_context_len, seq_len_partition_size); + const size_t total_tokens = input.Batch().v; + const size_t num_of_partitions = CeilDiv(params.conf.paged_attention_max_len, seq_len_partition_size); const size_t heads_num = static_cast(params.conf.heads_num); const size_t head_size = static_cast(params.conf.head_size); - if (kernel_idx == 0) { - dispatch_data.gws = { sequences_number, + if (kernel_idx == KernelsTypes::SINGLE_TOKEN || kernel_idx == KernelsTypes::MULTI_TOKENS) { + dispatch_data.gws = { total_tokens, heads_num, head_size * num_of_partitions }; dispatch_data.lws = { 1, 1, head_size }; + } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) { + const auto& past_lens = params.inputs[3]; + const auto subsequences_number = past_lens.Batch().v; + + size_t partition_size = 0; + size_t num_of_partitions = 0; + if (params.stage == PagedAttentionStage::PREFILL) { + partition_size = SDPAKernelOpt::get_seq_len_partition_size(params, params.conf.head_size, 1); + } else { + partition_size = seq_len_partition_size; + } + + num_of_partitions = CeilDiv(params.conf.paged_attention_max_len, partition_size); + + dispatch_data.gws = { partition_size * num_of_partitions, + 1, + subsequences_number }; + dispatch_data.lws = { partition_size, 1, 1 }; } else { - dispatch_data.gws = { sequences_number, + dispatch_data.gws = { total_tokens, heads_num, head_size }; dispatch_data.lws = { 1, 1, subgroup_size }; @@ -228,30 +293,39 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) { const auto& prim_params = static_cast(params); - const size_t expected_kernels_num = 4; - OPENVINO_ASSERT(kd.kernels.size() == expected_kernels_num, "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel"); + const auto has_scores_output = prim_params.outputs.size() > 1; + const auto expected_kernels_num = has_scores_output ? KernelsTypes::TOTAL_KERNELS_NUM : KernelsTypes::TOTAL_KERNELS_NUM - 1; + OPENVINO_ASSERT(kd.kernels.size() == static_cast(expected_kernels_num), + "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel"); + + const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output; + const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED; auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN); kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws; kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws; - kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only; kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws; kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws; - kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only; - const auto& input = prim_params.inputs[0]; - const size_t sequences_number = input.Batch().v; - const size_t num_of_partitions = CeilDiv(prim_params.max_context_len, seq_len_partition_size); + size_t partition_size = 0; + if (prim_params.stage == PagedAttentionStage::PREFILL) { + partition_size = SDPAKernelOpt::get_seq_len_partition_size(params, prim_params.conf.head_size, 1); + } else { + partition_size = seq_len_partition_size; + } + const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, partition_size); auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION); kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws; kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws; - kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only; kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws; kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws; - kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only; ScalarDescriptor num_of_partitions_scalar; num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32; @@ -261,23 +335,63 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars.resize(1); kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar; + if (has_scores_output) { + auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION); + kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws; + kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws; + kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false; + + ScalarDescriptor is_mixed_mode; + is_mixed_mode.t = ScalarDescriptor::Types::UINT32; + is_mixed_mode.v.u32 = static_cast(multi_tokens_mode); + kd.kernels[KernelsTypes::SCORES_CALCULATION].params.scalars.resize(1); + kd.kernels[KernelsTypes::SCORES_CALCULATION].params.scalars[0] = is_mixed_mode; + } + + const auto& input = prim_params.inputs[0]; + const size_t total_tokens = input.Batch().v; + auto buf_dt_size = BytesPerElement(softmax_acc_dt); - auto buf_elements_count = sequences_number * prim_params.conf.heads_num * num_of_partitions; + auto buf_elements_count = total_tokens * prim_params.conf.heads_num * num_of_partitions; auto buf_size = buf_elements_count * buf_dt_size; auto tmp_out_dt_size = BytesPerElement(softmax_acc_dt); - auto tmp_out_elements_count = sequences_number * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions; + auto tmp_out_elements_count = total_tokens * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions; auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size; kd.internalBufferSizes.clear(); - kd.internalBufferSizes.push_back(buf_size); - kd.internalBufferSizes.push_back(buf_size); - kd.internalBufferSizes.push_back(tmp_out_size); + + if (has_scores_output) { + const auto& past_lens = prim_params.inputs[3]; + auto subsequences_number = past_lens.Batch().v; + auto softmax_buf_dt_size = BytesPerElement(softmax_acc_dt); + + auto softmax_buf_elements_count = subsequences_number * prim_params.conf.heads_num * num_of_partitions * partition_size; + auto softmax_buf_size = softmax_buf_elements_count * softmax_buf_dt_size; + + // Softmax intermediate output + kd.internalBufferSizes.push_back(softmax_buf_size); + // Precalculated accumulated sequence length offsets for each subsequence + kd.internalBufferSizes.push_back(subsequences_number * BytesPerElement(Datatype::INT32)); + + if (prim_params.stage == PagedAttentionStage::PREFILL) { + // Recalculate buf_size as in case of PREFILL stage it's not needed to allocate buffer per each input token + buf_elements_count = subsequences_number * prim_params.conf.heads_num * num_of_partitions; + buf_size = buf_elements_count * buf_dt_size; + + // Intermediate tmp output buffer is not used for PREFILL stage + tmp_out_size = tmp_out_dt_size; + } + } + + kd.internalBufferSizes.push_back(buf_size); // softmax exp_sums + kd.internalBufferSizes.push_back(buf_size); // softmax max_logits + kd.internalBufferSizes.push_back(tmp_out_size); // intermediate output kd.internalBufferDataType = softmax_acc_dt; - if (prim_params.multi_tokens_mode) { + if (multi_tokens_mode) { auto buf_dt_size = BytesPerElement(Datatype::INT32); - auto buf_elements_count = sequences_number; + auto buf_elements_count = total_tokens; auto buf_size = Align(buf_elements_count * buf_dt_size, BytesPerElement(softmax_acc_dt)); kd.internalBufferSizes.push_back(buf_size); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h index a2456ccd9e2af5..a52571b03691df 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h @@ -9,11 +9,17 @@ namespace kernel_selector { +enum PagedAttentionStage { + GENERATE = 0, + PREFILL = 1, + MIXED = 2, + UNKNOWN = 3 +}; + struct pa_sdpa_params : base_params { pa_sdpa_params() : base_params(KernelType::PA_SDPA) {} - bool multi_tokens_mode = false; - size_t max_context_len = 0; + PagedAttentionStage stage = PagedAttentionStage::UNKNOWN; sdpa_configuration conf; }; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h index 5cd9c384ff2709..8fcc4a16692d6c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -97,6 +97,7 @@ struct sdpa_configuration { bool is_paged_attention = false; int64_t paged_attention_aligned_seq_len = -1; int64_t paged_attention_block_size = 0; + int64_t paged_attention_max_len = 0; bool has_const_scale_val = false; float scale_val = 0.f; }; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp index 4e71064efbc895..4c23d4de4fd68d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp @@ -21,38 +21,11 @@ enum KernelsTypes { constexpr size_t subgroup_size = 16; } // namespace -static size_t get_sg_number_scale_factor(const sdpa_params& sdpa_params, size_t kernel_type) { - const size_t optimal_scale_factor = 2; - if (kernel_type == KernelsTypes::MULTI_TOKENS) { - if (sdpa_params.conf.head_size * optimal_scale_factor <= sdpa_params.engineInfo.maxWorkGroupSize) { - return optimal_scale_factor; - } - } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) { - if (sdpa_params.conf.head_size * optimal_scale_factor <= sdpa_params.engineInfo.maxWorkGroupSize && - sdpa_params.conf.head_size * optimal_scale_factor / subgroup_size <= subgroup_size) { - return optimal_scale_factor; - } - } - - return 1; -} - static size_t get_target_seq_len_block_size() { const size_t block_size = 16; return block_size; } -static size_t get_seq_len_partition_size(const sdpa_params& sdpa_params, size_t kernel_type) { - size_t seq_len = 0; - if (kernel_type == KernelsTypes::MULTI_TOKENS) { - seq_len = sdpa_params.conf.head_size * get_sg_number_scale_factor(sdpa_params, kernel_type); - } else { - seq_len = 256; - } - - return seq_len; -} - static Datatype get_softmax_acc_type() { return Datatype::F32; } @@ -71,7 +44,7 @@ static size_t get_partitions_num(const sdpa_params& sdpa_params, size_t kernel_t TransposedDimensionAccessHelperBase dims_k(sdpa_params.inputs[1], sdpa_params.input1_order); auto source_seq_len = dims_k.y_dim().v; - return CeilDiv(source_seq_len, get_seq_len_partition_size(sdpa_params, kernel_type)); + return CeilDiv(source_seq_len, SDPAKernelOpt::get_seq_len_partition_size(sdpa_params, sdpa_params.conf.head_size, kernel_type)); } static std::vector get_internal_buffer_sizes(const sdpa_params& sdpa_params, size_t kernel_type) { @@ -130,6 +103,33 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type, const return kernel_name; } +size_t SDPAKernelOpt::get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type) { + const size_t optimal_scale_factor = 2; + if (kernel_type == KernelsTypes::MULTI_TOKENS) { + if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) { + return optimal_scale_factor; + } + } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) { + if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize && + head_size * optimal_scale_factor / subgroup_size <= subgroup_size) { + return optimal_scale_factor; + } + } + + return 1; +} + +size_t SDPAKernelOpt::get_seq_len_partition_size(const Params& params, size_t head_size, size_t kernel_type) { + size_t seq_len = 0; + if (kernel_type == KernelsTypes::MULTI_TOKENS) { + seq_len = head_size * get_sg_number_scale_factor(params, head_size, kernel_type); + } else { + seq_len = 256; + } + + return seq_len; +} + ParamsKey SDPAKernelOpt::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::INT8); @@ -176,14 +176,14 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke const auto& config = params.conf; jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size)); jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size)); - jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size(params, kernel_idx))); + jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size(params, config.head_size, kernel_idx))); auto target_seq_len_block_size = kernel_idx == KernelsTypes::SINGLE_TOKEN ? 1 : get_target_seq_len_block_size(); jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN_BLOCK_SIZE", target_seq_len_block_size)); auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION ? 1 : 0; jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1)); - jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, kernel_idx))); + jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx))); if (params.conf.is_paged_attention) { if (params.conf.has_alibi_input) { @@ -196,6 +196,10 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke } else { jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", 1)); } + + if (params.outputs.size() > 1) { + jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_SCORES_OUTPUT", 1)); + } } else if (params.inputs.size() <= 4) { jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE_INV", std::sqrt(static_cast(params.conf.head_size)))); jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE", 1.0f / std::sqrt(static_cast(params.conf.head_size)))); @@ -218,11 +222,11 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k if (params.conf.is_paged_attention) { OPENVINO_ASSERT(kernel_idx == KernelsTypes::MULTI_TOKENS); - const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx); const size_t heads_num = static_cast(params.conf.heads_num); + const size_t head_size = static_cast(params.conf.head_size); + const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx); const size_t target_seq_len_block_size = get_target_seq_len_block_size(); const size_t target_seq_len = static_cast(params.conf.paged_attention_aligned_seq_len); - const size_t head_size = static_cast(params.conf.head_size); dispatch_data.gws = { heads_num, CeilDiv(target_seq_len, target_seq_len_block_size), @@ -243,13 +247,13 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k const size_t target_seq_len_block_size = kernel_idx == 1 ? get_target_seq_len_block_size() : 1; if (kernel_idx == KernelsTypes::SINGLE_TOKEN) { - const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx); + const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx); dispatch_data.gws = { batch_size * heads_num, CeilDiv(target_seq_len, target_seq_len_block_size), head_size * num_of_partitions * sg_num_scale }; dispatch_data.lws = { 1, 1, head_size * sg_num_scale }; } else if (kernel_idx == KernelsTypes::MULTI_TOKENS) { - const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx); + const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx); dispatch_data.gws = { batch_size * heads_num, CeilDiv(target_seq_len, target_seq_len_block_size), head_size * sg_num_scale }; @@ -317,7 +321,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const { false, inputs_num, GetFusedPrimitiveInputsCount(params), - static_cast(prim_params.outputs.size()), + 1 /* number_of_outputs */, prim_params.is_shape_agnostic); auto beam_table_idx = prim_params.inputs.size(); @@ -339,6 +343,19 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const { kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + if (prim_params.conf.is_paged_attention && prim_params.outputs.size() > 1) { + // Intermediate buffers for PagedAttention scores calculation: + // softmax_results, subsequence_offsets, exp_sums, max_logits, tmp_out + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 4}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 5}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 6}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 7}); + + // Scalar used for proper offset calculation of intermediate data buffers + kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); + } + const auto buf_sizes = get_internal_buffer_sizes(prim_params, kernel_idx); if (!prim_params.conf.is_paged_attention) { kd.internalBufferSizes.clear(); @@ -379,6 +396,15 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const { kernel_data.kernels[0].params.workGroups.global = dispatch_data.gws; kernel_data.kernels[0].params.workGroups.local = dispatch_data.lws; kernel_data.kernels[0].skip_execution = false; + + if (prim_params.outputs.size() > 1) { + const auto max_seq_len = prim_params.conf.paged_attention_max_len; + const auto seq_len_partition_size = get_seq_len_partition_size(params, prim_params.conf.head_size, KernelsTypes::MULTI_TOKENS); + + kernel_data.kernels[0].params.scalars.resize(1); + kernel_data.kernels[0].params.scalars[0].t = ScalarDescriptor::Types::UINT32; + kernel_data.kernels[0].params.scalars[0].v.u32 = static_cast(Align(max_seq_len, seq_len_partition_size)); + } } else { const auto num_of_partitions = get_partitions_num(prim_params, KernelsTypes::SINGLE_TOKEN); const auto buf_sizes = get_internal_buffer_sizes(prim_params, KernelsTypes::SINGLE_TOKEN); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h index 8d7279f5546112..a4d351498d7075 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h @@ -17,6 +17,9 @@ class SDPAKernelOpt : public SDPAKernelBase { KernelsPriority GetKernelsPriority(const Params& params) const override; ParamsKey GetSupportedKey() const override; + static size_t get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type); + static size_t get_seq_len_partition_size(const Params& params, size_t head_size, size_t kernel_type); + protected: bool Validate(const Params& p) const override; void GetUpdateDispatchDataFunc(KernelData& kd) const override; diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 7425b096b6d324..d82d3a66fed7f7 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -61,10 +61,13 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared OPENVINO_ASSERT(alibi_const != nullptr); prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0; + prim.num_outputs = 1; if (op->get_output_size() > 1) { const auto scores_output_idx = 1; const auto& users = op->get_output_target_inputs(scores_output_idx); - OPENVINO_ASSERT(users.size() == 0, "[GPU] PagedAttention implementation doesn't support scores output yet"); + if (users.size() > 0) { + prim.num_outputs++; // Add scores output + } } p.add_primitive(*op, prim); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp new file mode 100644 index 00000000000000..a32ef3325cd9bc --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp @@ -0,0 +1,687 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test_utils.h" +#include "random_generator.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace cldnn; +using namespace ov::intel_gpu; +using namespace ::tests; + +/* +* PagedAttention inputs: +* [0]: query +* shape: [batch_size_in_tokens, num_heads * head_size], type: f16 +* [1]: key +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [2]: value  +* shape: [batch_size_in_tokens, num_kv_heads * head_size], type: f16 +* [3]: key_cache +* shape: [num_blocks, num_kv_heads, head_size, block_size], type: f16 +* [4]: value_cache +* shape: [num_blocks, num_kv_heads, block_size, head_size], type: f16 +* [5]: past_lens +* shape: [batch_size_in_sequences], type: i32 +* [6]: subsequence_begins +* shape: [batch_size_in_sequences + 1], type: i32 +* [7]: block_indices +* Shape: [num_blocks], type: i32 +* [8]: block_indices_begins +* Shape: [batch_size_in_sequences + 1], type: i32 +* [9]: scale, optional +* [10]: sliding_window, optional +* [11]: alibi_slopes, optional +* [12]: max_context_len +* shape: [], type: i32 +*/ + +struct SubsequenceDescriptor { + int num_tokens; + int past_len; +}; + +struct PagedAttentionManager { + int num_heads; + int head_size; + int block_size; + std::vector subsequence_descs; + + // per-subsequence QKV inputs + std::vector> query_data; // {[1, num_tokens, num_heads, head_size], ..} + std::vector> key_data; // {[1, past_len + num_tokens, num_heads, head_size], ..} + std::vector> value_data; // {[1, past_len + num_tokens, num_heads, head_size], ..} + + // common PA inputs + std::vector past_lens; + std::vector subsequence_begins; + std::vector block_indices; + std::vector block_indices_begins; + std::vector max_context_len; + + cldnn::engine& test_engine; + cldnn::stream& test_stream; + tests::random_generator& rg; + + PagedAttentionManager(tests::random_generator& rg, + cldnn::engine& engine, + cldnn::stream& stream, + const std::vector& subsequence_descs, + int num_heads, + int head_size, + int block_size) + : num_heads(num_heads) + , head_size(head_size) + , block_size(block_size) + , subsequence_descs(subsequence_descs) + , test_engine(engine) + , test_stream(stream) + , rg(rg) { + // init subsequence_begins and block_indices_begins + subsequence_begins.push_back(0); + block_indices_begins.push_back(0); + + int max_len = 0; + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + const auto& subsequence_desc = subsequence_descs[i]; + max_len = std::max(max_len, subsequence_desc.num_tokens + subsequence_desc.past_len); + + query_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens, head_size)); + key_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, head_size)); + value_data.push_back(generate_input_data(rg, num_heads, subsequence_desc.num_tokens + subsequence_desc.past_len, head_size)); + + past_lens.push_back(subsequence_desc.past_len); + int subsequence_start_pos = subsequence_begins[i]; + int subsequence_end_pos = subsequence_start_pos + subsequence_desc.num_tokens; + subsequence_begins.push_back(subsequence_end_pos); + + int subsequence_length = subsequence_desc.num_tokens + subsequence_desc.past_len; + int required_blocks = ceil_div(subsequence_length, block_size); + int start_block_idx = block_indices.empty() ? 0 : block_indices.back() + 1; + int end_block_idx = start_block_idx + required_blocks; + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + block_indices.push_back(block_idx); + } + + int block_indices_start_pos = block_indices_begins[i]; + int block_indices_end_pos = block_indices_start_pos + required_blocks; + block_indices_begins.push_back(block_indices_end_pos); + } + max_context_len.push_back(max_len); + } + + memory::ptr get_query_memory() { + return get_QKV_memory(query_data, false); + } + + memory::ptr get_key_memory() { + return get_QKV_memory(key_data, true); + } + + memory::ptr get_value_memory() { + return get_QKV_memory(value_data, true); + } + + memory::ptr get_key_cache_memory() { + auto num_blocks = block_indices.back() + 1; + auto key_cache_shape = ov::PartialShape{ num_blocks, num_heads, head_size, block_size }; + auto key_cache_layout = layout{ key_cache_shape, data_types::f16, format::bfyx }; + auto memory = test_engine.allocate_memory(key_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int head_size_idx = 0; head_size_idx < head_size; head_size_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = key_data[i].data() + + input_token_offset * num_heads * head_size + + head_idx * head_size + head_size_idx; + + // shape: [num_blocks, num_heads, head_size, block_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_size_idx * block_size + + token_idx; + + set_values(test_stream, memory, data_ptr, 1, output_offset); + } + } + } + } + } + } + + return memory; + } + + memory::ptr get_value_cache_memory() { + auto num_blocks = block_indices.back() + 1; + auto value_cache_shape = ov::PartialShape{ num_blocks, num_heads, block_size, head_size }; + auto value_cache_layout = layout{ value_cache_shape, data_types::f16, format::bfyx }; + auto memory = test_engine.allocate_memory(value_cache_layout); + + for (int i = 0; i < static_cast(subsequence_descs.size()); i++) { + int past_len = subsequence_descs[i].past_len; + if (past_len != 0) { + int blocks_num = ceil_div(past_len, block_size); + int start_block_idx = block_indices[block_indices_begins[i]]; + for (int block_idx = 0; block_idx < blocks_num; block_idx++) { + int last_token_idx = block_idx == blocks_num - 1 ? past_len % block_size + : block_size; + for (int token_idx = 0; token_idx < last_token_idx; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = block_idx * block_size + token_idx; + ov::float16* data_ptr = value_data[i].data() + + input_token_offset * num_heads * head_size + + head_idx * head_size; + + // shape: [num_blocks, num_heads, block_size, head_size] + size_t output_offset = (start_block_idx + block_idx) * num_heads * block_size * head_size + + head_idx * block_size * head_size + + token_idx * head_size; + + set_values(test_stream, memory, data_ptr, head_size, output_offset); + } + } + } + } + } + + return memory; + } + + memory::ptr get_past_lens_memory() { + return get_memory_from_vec(past_lens); + } + + memory::ptr get_subsequence_begins_memory() { + return get_memory_from_vec(subsequence_begins); + } + + memory::ptr get_block_indices_memory() { + return get_memory_from_vec(block_indices); + } + + memory::ptr get_block_indices_begins_memory() { + return get_memory_from_vec(block_indices_begins); + } + + memory::ptr get_scale_memory() { + std::vector scale = { ov::float16(get_default_scale()) }; + return get_memory_from_vec(scale); + } + + memory::ptr get_sliding_window_memory() { + std::vector sliding_window = { 0 }; + return get_memory_from_vec(sliding_window); + } + + memory::ptr get_alibi_memory() { + std::vector alibi; + return get_memory_from_vec(alibi); + } + + memory::ptr get_max_context_len_memory() { + return get_memory_from_vec(max_context_len); + } + + float get_default_scale() { + return static_cast(1.f / std::sqrt(head_size)); + } + +private: + template + memory::ptr get_memory_from_vec(std::vector& input_data) { + auto data_size = input_data.empty() ? 1 : input_data.size(); + auto shape = ov::PartialShape{ static_cast(data_size) }; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + auto memory = test_engine.allocate_memory(layout); + + if (input_data.empty()) { + auto shape = ov::PartialShape{0}; + auto layout = cldnn::layout{ shape, ov::element::from(), format::bfyx }; + return test_engine.reinterpret_buffer(*memory, layout); + } + + set_values(test_stream, memory, input_data.data(), input_data.size(), 0); + + return memory; + } + + memory::ptr get_QKV_memory(std::vector>& input_data, bool skip_past_len) { + int total_tokens = 0; + for (const auto& subsequence_desc : subsequence_descs) + total_tokens += subsequence_desc.num_tokens; + + auto query_shape = ov::PartialShape{ total_tokens, num_heads * head_size }; + auto query_layout = layout{ query_shape, data_types::f16, format::bfyx }; + auto memory = test_engine.allocate_memory(query_layout); + + for (int subsequence_idx = 0; subsequence_idx < static_cast(subsequence_descs.size()); subsequence_idx++) { + for (int token_idx = 0; token_idx < subsequence_descs[subsequence_idx].num_tokens; token_idx++) { + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + size_t input_token_offset = token_idx; + // as generated data stored in vectors includes past_len, ignore it for KV inputs + if (skip_past_len) + input_token_offset += subsequence_descs[subsequence_idx].past_len; + + ov::float16* data_ptr = input_data[subsequence_idx].data() + + input_token_offset * num_heads * head_size + + head_idx * head_size; + + size_t output_token_offset = subsequence_begins[subsequence_idx] + token_idx; + size_t output_offset = output_token_offset * num_heads * head_size + + head_idx * head_size; + + set_values(test_stream, memory, data_ptr, head_size, output_offset); + } + } + } + + return memory; + } + + template + static void set_values(stream& stream, memory::ptr mem, T* vals, size_t size, size_t dst_offset) { + mem_lock mem_ptr(mem, stream); + for (size_t i = 0; i < size; i++) { + mem_ptr[dst_offset + i] = vals[i]; + } + } + + static std::vector generate_input_data(tests::random_generator& rg, size_t num_heads, size_t tokens_num, size_t head_size) { + const size_t total_elements_num = tokens_num * num_heads * head_size; + auto data = rg.generate_random_1d(total_elements_num, -1, 1); + + return data; + } +}; + +struct PagedAttentionReference { + PagedAttentionReference(PagedAttentionManager& pam) + : pam(pam) + , test_engine(pam.test_engine) + , test_stream(pam.test_stream) {} + + std::pair, std::vector> get_reference() { + std::vector ref_data_output; + std::vector ref_scores_output; + + for (size_t i = 0; i < pam.subsequence_descs.size(); i++) { + const auto& subsequence_desc = pam.subsequence_descs[i]; + const auto kv_seq_len = subsequence_desc.num_tokens + subsequence_desc.past_len; + auto subsequence_ref_results = run_reference(pam.query_data[i], + pam.key_data[i], + pam.value_data[i], + subsequence_desc.num_tokens, + kv_seq_len, + pam.num_heads, + pam.head_size, + pam.get_default_scale()); + + // concatenate all subsequences into one vector + ref_data_output.insert(ref_data_output.end(), + subsequence_ref_results.first.begin(), + subsequence_ref_results.first.end()); + ref_scores_output.insert(ref_scores_output.end(), + subsequence_ref_results.second.begin(), + subsequence_ref_results.second.end()); + } + + return { ref_data_output, ref_scores_output }; + } + +private: + std::pair, std::vector> + run_reference(const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + int num_queries, + int num_keys, + int num_heads, + int head_size, + float scale) { + auto query_shape = ov::PartialShape{1, num_queries, num_heads, head_size}; + auto key_shape = ov::PartialShape{1, num_keys, num_heads, head_size}; + auto value_shape = ov::PartialShape{1, num_keys, num_heads, head_size}; + + auto query_layout = layout{query_shape, data_types::f16, format::bfyx}; + auto key_layout = layout{key_shape, data_types::f16, format::bfyx}; + auto value_layout = layout{value_shape, data_types::f16, format::bfyx}; + + OPENVINO_ASSERT(query_layout.count() == query_data.size()); + OPENVINO_ASSERT(key_layout.count() == key_data.size()); + OPENVINO_ASSERT(value_layout.count() == value_data.size()); + + auto query_mem = test_engine.allocate_memory(query_layout); + auto key_mem = test_engine.allocate_memory(key_layout); + auto value_mem = test_engine.allocate_memory(value_layout); + auto mask_mem = get_mask_mem(num_queries, num_keys, num_heads); + + set_values(query_mem, query_data); + set_values(key_mem, key_data); + set_values(value_mem, value_data); + + topology topology; + topology.add(input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + data("mask", mask_mem), + permute("query_transposed", input_info("query"), {0, 2, 1, 3}), + permute("key_transposed", input_info("key"), {0, 2, 1, 3}), + permute("value_transposed", input_info("value"), {0, 2, 1, 3}), + gemm("qk_gemm", { input_info("query_transposed"), input_info("key_transposed") }, data_types::f16, false, true, scale), + eltwise("eltwise", { input_info("qk_gemm"), input_info("mask") }, eltwise_mode::sum), + softmax("softmax", input_info("eltwise"), -1), + gemm("qkv_gemm", { input_info("softmax"), input_info("value_transposed") }, data_types::f16, false, false), + permute("qkv_gemm_transposed", input_info("qkv_gemm"), {0, 2, 1, 3}), + reorder("output_data", input_info("qkv_gemm_transposed"), format::bfyx, data_types::f16), + reorder("scores_data", input_info("softmax"), format::bfyx, data_types::f16) + ); + + ExecutionConfig config = get_test_default_config(test_engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + + network::ptr network = get_network(test_engine, topology, config, get_test_stream_ptr(), false); + network->set_input_data("query", query_mem); + network->set_input_data("key", key_mem); + network->set_input_data("value", value_mem); + + auto outputs = network->execute(); + + auto output_data_mem = outputs.at("output_data").get_memory(); + auto output_scores_mem = outputs.at("scores_data").get_memory(); + + return { get_output_data_vec(output_data_mem, num_queries, head_size, num_heads), + get_output_scores_vec(output_scores_mem, num_queries, num_keys, num_heads) }; + } + + std::vector get_output_scores_vec(memory::ptr scores_output, + int num_queries, + int num_keys, + int num_heads) { + OPENVINO_ASSERT(scores_output->count() == static_cast(num_heads * num_queries * num_keys)); + + std::vector output_scores(num_keys, 0); + mem_lock mem_ptr(scores_output, test_stream); + for (int head_idx = 0; head_idx < num_heads; head_idx++) { + for (int score_idx = 0; score_idx < num_keys; score_idx++) { + output_scores[score_idx] += mem_ptr[head_idx * num_queries * num_keys + + (num_queries - 1) * num_keys + + score_idx]; + } + } + + return output_scores; + } + + std::vector get_output_data_vec(memory::ptr data_output, + int num_queries, + int head_size, + int num_heads) { + OPENVINO_ASSERT(data_output->count() == static_cast(num_queries * num_heads * head_size)); + + std::vector output_data(data_output->count()); + mem_lock mem_ptr(data_output, test_stream); + for (size_t i = 0; i < data_output->count(); i++) + output_data[i] = mem_ptr[i]; + + return output_data; + } + + memory::ptr get_mask_mem(int num_queries, int num_keys, int num_heads) { + /* + * Two kinds of masks: + * + * Case 1 (N == K): + * num_queries = N + * num_keys = K = N + * head_size = H + * Q [N, H] * K[H, N] + * QK [N, N] + * 0 1 N + * 0 [ 0, MIN, .., MIN ] + * 1 [ 0, 0, .., MIN ] + * [ .., .., .., MIN ] + * N [ 0, 0, .., 0 ] + * + * Case 2 (N != K): + * num_queries = N + * num_keys = K + * head_size = H + * past_len = P = K - N + 1 + * Q [N, H] * K[H, K] + * QK [N, K] + * 0 1 2 P .. K + * 0 [ 0, 0, 0, MIN, MIN, MIN ] + * 1 [ 0, 0, 0, 0, MIN, MIN ] + * [ .., .., .., .., .., MIN ] + * N [ 0, 0, 0, 0, .., 0 ] + * + * Shapes: + * Q [1, num_heads, num_queries, head_size] + * K [1, num_heads, head_size, num_keys] + * Q*K [1, num_heads, num_queries, num_keys] + */ + + auto mask_shape = ov::PartialShape{ 1, 1, num_queries, num_keys }; + auto mask_layout = layout{mask_shape, data_types::f16, format::bfyx}; + auto mask_mem = test_engine.allocate_memory(mask_layout); + + int past_len = num_keys - num_queries + 1; + mem_lock mem_ptr(mask_mem, test_stream); + for (int i = 0; i < num_queries; i++) { + for (int j = 0; j < num_keys; j++) { + mem_ptr[i * num_keys + j] = j >= past_len + i ? std::numeric_limits::lowest() + : ov::float16(0.f); + } + } + + return mask_mem; + } + + + PagedAttentionManager& pam; + cldnn::engine& test_engine; + cldnn::stream& test_stream; +}; + +template +struct PagedAttentionTest : public ::testing::TestWithParam { +public: + random_generator rg; + cldnn::engine& engine = get_test_engine(); + float tolerance = 2e-3; + + void SetUp() override { + rg.set_seed(GET_SUITE_NAME); + } + + void execute(T& p) { + PagedAttentionManager pam(rg, get_test_engine(), get_test_stream(), p.subsequences, p.num_heads, p.head_size, p.block_size); + + auto query_mem = pam.get_query_memory(); + auto key_mem = pam.get_key_memory(); + auto value_mem = pam.get_value_memory(); + + auto key_cache_mem = pam.get_key_cache_memory(); + auto value_cache_mem = pam.get_value_cache_memory(); + + auto past_lens_mem = pam.get_past_lens_memory(); + auto subsequence_begins_mem = pam.get_subsequence_begins_memory(); + auto block_indices_mem = pam.get_block_indices_memory(); + auto block_indices_begins_mem = pam.get_block_indices_begins_memory(); + + auto scale_mem = pam.get_scale_memory(); + auto sliding_window_mem = pam.get_sliding_window_memory(); + auto alibi_mem = pam.get_alibi_memory(); + auto max_context_len_mem = pam.get_max_context_len_memory(); + + auto query_layout = query_mem->get_layout(); + auto key_layout = key_mem->get_layout(); + auto value_layout = value_mem->get_layout(); + auto key_cache_layout = key_cache_mem->get_layout(); + auto value_cache_layout = value_cache_mem->get_layout(); + auto past_lens_layout = past_lens_mem->get_layout(); + auto subsequence_begins_layout = subsequence_begins_mem->get_layout(); + auto block_indices_layout = block_indices_mem->get_layout(); + auto block_indices_begins_layout = block_indices_begins_mem->get_layout(); + auto scale_layout = scale_mem->get_layout(); + auto sliding_window_layout = sliding_window_mem->get_layout(); + auto alibi_layout = alibi_mem->get_layout(); + auto max_context_len_layout = max_context_len_mem->get_layout(); + + // make layouts dynamic + query_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.head_size }); + key_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.head_size }); + value_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads * p.head_size }); + key_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.head_size, p.block_size }); + value_cache_layout.set_partial_shape(ov::PartialShape{ -1, p.num_heads, p.block_size, p.head_size }); + past_lens_layout.set_partial_shape(ov::PartialShape{ -1 }); + subsequence_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); + block_indices_layout.set_partial_shape(ov::PartialShape{ -1 }); + block_indices_begins_layout.set_partial_shape(ov::PartialShape{ -1 }); + + auto pa_prim = paged_attention("paged_attention", { input_info("query"), + input_info("key"), + input_info("value"), + input_info("key_cache"), + input_info("value_cache"), + input_info("past_lens"), + input_info("subsequence_begins"), + input_info("block_indices"), + input_info("block_indices_begins"), + input_info("scale"), + input_info("sliding_window"), + input_info("alibi"), + input_info("max_context_len") }); + + pa_prim.head_size = p.head_size; + pa_prim.kv_heads_num = p.num_heads; + pa_prim.heads_num = p.num_heads; + pa_prim.scale_val = pam.get_default_scale(); + pa_prim.has_alibi = false; + pa_prim.num_outputs = p.scores_output ? 2 : 1; + + topology topology; + topology.add( + input_layout("query", query_layout), + input_layout("key", key_layout), + input_layout("value", value_layout), + input_layout("key_cache", key_cache_layout), + input_layout("value_cache", value_cache_layout), + input_layout("past_lens", past_lens_layout), + input_layout("subsequence_begins", subsequence_begins_layout), + input_layout("block_indices", block_indices_layout), + input_layout("block_indices_begins", block_indices_begins_layout), + input_layout("scale", scale_layout), + input_layout("sliding_window", sliding_window_layout), + input_layout("alibi", alibi_layout), + input_layout("max_context_len", max_context_len_layout), + pa_prim, + reorder("output_data", input_info("paged_attention", 0), format::bfyx, data_types::f16) + ); + + if (p.scores_output) { + topology.add(reorder("output_scores", input_info("paged_attention", 1), format::bfyx, data_types::f16)); + } + + ExecutionConfig config = get_test_default_config(get_test_engine()); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + + network::ptr network = get_network(get_test_engine(), topology, config, get_test_stream_ptr(), false); + network->set_input_data("query", query_mem); + network->set_input_data("key", key_mem); + network->set_input_data("value", value_mem); + network->set_input_data("key_cache", key_cache_mem); + network->set_input_data("value_cache", value_cache_mem); + network->set_input_data("past_lens", past_lens_mem); + network->set_input_data("subsequence_begins", subsequence_begins_mem); + network->set_input_data("block_indices", block_indices_mem); + network->set_input_data("block_indices_begins", block_indices_begins_mem); + network->set_input_data("scale", scale_mem); + network->set_input_data("sliding_window", sliding_window_mem); + network->set_input_data("alibi", alibi_mem); + network->set_input_data("max_context_len", max_context_len_mem); + + auto outputs = network->execute(); + + cldnn::memory::ptr output_data_mem = nullptr; + cldnn::memory::ptr output_scores_mem = nullptr; + + output_data_mem = outputs.at("output_data").get_memory(); + if (p.scores_output) { + output_scores_mem = outputs.at("output_scores").get_memory(); + } + + auto ref_data = PagedAttentionReference(pam).get_reference(); + compare(output_data_mem, output_scores_mem, ref_data); + } + + void compare(memory::ptr data_output_mem, memory::ptr scores_output_mem, std::pair, std::vector> ref_data) { + if (data_output_mem) { + ASSERT_EQ(data_output_mem->count(), ref_data.first.size()); + mem_lock mem_ptr(data_output_mem, get_test_stream()); + for (size_t i = 0; i < data_output_mem->count(); i++) { + ASSERT_NEAR(mem_ptr[i], ref_data.first[i], tolerance); + } + } + + if (scores_output_mem) { + ASSERT_EQ(scores_output_mem->count(), ref_data.second.size()); + mem_lock mem_ptr(scores_output_mem, get_test_stream()); + for (size_t i = 0; i < scores_output_mem->count(); i++) { + ASSERT_NEAR(mem_ptr[i], ref_data.second[i], tolerance); + } + } + } +}; + +struct paged_attention_test_params { + std::vector subsequences; + int num_heads; + int head_size; + int block_size; + bool scores_output; +}; + +class paged_attention_test : public PagedAttentionTest {}; +TEST_P(paged_attention_test, basic) { + auto p = GetParam(); + + execute(p); +} + +INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector{ + /* with scores output */ + paged_attention_test_params{ {{10, 0}}, 2, 64, 16, true }, // 1st token + paged_attention_test_params{ {{36, 0}}, 2, 64, 16, true }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, true }, // 1st token long + paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, true }, // 1st token + 1st token + paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, true }, // 1st token + 1st token + paged_attention_test_params{ {{1, 10}}, 2, 64, 16, true }, // 2nd token + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, true }, // 2nd token + 2nd token + paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, true }, // mixed: 2nd token + 1st token + part of 1st token + /* without scores output */ + paged_attention_test_params{ {{10, 0}}, 2, 64, 16, false }, // 1st token + paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, false }, // 1st token long + paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, false }, // 2nd token + 2nd token +}));