Skip to content

Commit

Permalink
[GPU] Add scores output support for PagedAttention (#28205)
Browse files Browse the repository at this point in the history
### Details:
 - Added scores output support for PagedAttention
 - Added PagedAttention unit tests

### Tickets:
- [CVS-153660](https://jira.devtools.intel.com/browse/CVS-153660)
  • Loading branch information
sshlyapn authored Dec 27, 2024
1 parent c040b7b commit e2ac535
Show file tree
Hide file tree
Showing 14 changed files with 1,428 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ struct paged_attention : public primitive_base<paged_attention> {
OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
}

bool has_scores_output() const {
return num_outputs == 2;
}

bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
Expand Down
279 changes: 208 additions & 71 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@
#include "intel_gpu/primitives/paged_attention.hpp"
#include "primitive_inst.h"

#include "sdpa/pa_sdpa_kernel_opt.h"

namespace cldnn {

enum PagedAttentionStage {
GENERATE = 0,
PREFILL = 1,
MIXED = 2,
UNKNOWN = 3
};
using PagedAttentionStage = kernel_selector::PagedAttentionStage;

PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param);

Expand Down Expand Up @@ -61,6 +58,9 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<p
memory::ptr block_indices_memory_ptr() const { return input_memory_ptr(7); }
memory::ptr block_indices_begins_memory_ptr() const { return input_memory_ptr(8); }
memory::ptr alibi_memory_ptr() const { return input_memory_ptr(11); }
memory::ptr rotated_block_indices_memory_ptr() const { return input_memory_ptr(13); }
memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }

std::shared_ptr<network> prefill_network;

Expand Down
87 changes: 75 additions & 12 deletions src/plugins/intel_gpu/src/graph/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,38 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no

template<typename ShapeType>
std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
auto out_layout = impl_param.get_input_layout(0);
auto data_layout = impl_param.get_input_layout(0);

const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. "
"Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length());

return {out_layout};
std::vector<layout> output_layouts{ data_layout };

const auto& desc = impl_param.typed_desc<paged_attention>();
if (desc->has_scores_output()) {
const auto past_lens_idx = 5;
const auto output_dt = data_layout.data_type;
if (impl_param.get_input_layout(past_lens_idx).is_static()) {
const auto& memory_deps = impl_param.memory_deps;
const auto past_lens_mem = memory_deps.at(past_lens_idx);
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);

long int total_size = 0;
for (size_t i = 0; i < past_lens_mem_lock.size(); i++) {
total_size += past_lens_mem_lock[i];
}

total_size += static_cast<long int>(impl_param.get_input_layout(0).get_shape()[0]);

output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
} else {
output_layouts.push_back(layout{ov::PartialShape::dynamic(1), output_dt, format::bfyx});
}
}

return output_layouts;
}

template std::vector<layout>
Expand All @@ -81,45 +105,79 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
}

void paged_attention_inst::on_execute() {
auto stage = get_paged_attention_stage(*_impl_params);
const auto& desc = _impl_params->typed_desc<paged_attention>();
const bool has_scores_output = desc->has_scores_output();
const auto stage = get_paged_attention_stage(*_impl_params);

if (stage == PagedAttentionStage::UNKNOWN ||
stage == PagedAttentionStage::GENERATE)
if ((stage == PagedAttentionStage::UNKNOWN) ||
(stage == PagedAttentionStage::GENERATE && !has_scores_output))
return;

auto& stream = get_network().get_stream();
const auto past_lens_mem = past_lens_memory_ptr();
const auto subsequence_begins_mem = subsequence_begins_memory_ptr();
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;

if (has_scores_output) {
const size_t subsequence_offsets_idx = 4;

OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
"[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");

auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
}

if (stage == PagedAttentionStage::GENERATE) {
// For the generate stage it's not necessary to configure any other intermediate
// buffers. Simply calculate the offsets and exit
size_t subsequence_offsets_acc = 0;
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
const auto past_len = past_lens_mem_lock[i];
const auto seq_start = subsequence_begins_mem_lock[i];
const auto seq_end = subsequence_begins_mem_lock[i + 1];
const auto seq_length = seq_end - seq_start;

if (subsequence_offsets_lock) {
subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
subsequence_offsets_acc += seq_length + past_len;
}
}

return;
}

OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");

const auto blocks_indexes_start_idx = 0;
const auto blocks_indexes_end_idx = 1;
const auto blocked_gws_subseq_mapping_idx = 2;

const auto past_lens_mem = past_lens_memory_ptr();
auto subsequence_begins_mem = subsequence_begins_memory_ptr();
auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];

OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);

auto& stream = get_network().get_stream();
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;

if (stage == PagedAttentionStage::MIXED) {
const auto sequential_gws_subseq_mapping_idx = 6;
const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;

OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx,
"Unexpected number of intermediates buffers for Paged Attention for mixed stage");
"[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage");

auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
}

size_t index = 0;
size_t subsequence_offsets_acc = 0;
const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
const auto past_len = past_lens_mem_lock[i];
Expand Down Expand Up @@ -159,6 +217,11 @@ void paged_attention_inst::on_execute() {
sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
}
}

if (subsequence_offsets_lock) {
subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
subsequence_offsets_acc += seq_length + past_len;
}
}
}

Expand Down
182 changes: 182 additions & 0 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ KERNEL(pa_sdpa_opt)(
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif
__global OUTPUT_TYPE* output,
#if PAGED_ATTENTION_SCORES_OUTPUT
__global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
const __global int* subsequence_offsets,
#endif
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
__global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
__global OUTPUT_TYPE* tmp_out
Expand Down Expand Up @@ -276,6 +280,28 @@ KERNEL(pa_sdpa_opt)(
const uint max_logits_offset = exp_sums_offset;
max_logits[max_logits_offset] = qk_max;
}

#if PAGED_ATTENTION_SCORES_OUTPUT
#if MULTI_TOKENS_PROCESSING
const uint subsequence_idx = gws_subseq_mapping[seq_idx];
const uint subsequence_start_pos = subsequence_begins[subsequence_idx];
const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
const bool save_softmax_results = seq_idx == subsequence_end_pos - 1;
#else
const uint subsequence_idx = seq_idx;
const bool save_softmax_results = true;
#endif // MULTI_TOKENS_PROCESSING
// PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
// so save SEQ_LEN_PARTITION_SIZE elements for each partition
if (save_softmax_results) {
const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
partition_idx * SEQ_LEN_PARTITION_SIZE;
for (uint i = sgid * SUBGROUP_SIZE + sglid; i < SEQ_LEN_PARTITION_SIZE; i += SUBGROUPS_PER_WG * SUBGROUP_SIZE) {
softmax_results[output_offset + i] = slm_qk_vals[i];
}
}
#endif // PAGED_ATTENTION_SCORES_OUTPUT
}
}

Expand Down Expand Up @@ -370,6 +396,10 @@ KERNEL(pa_sdpa_finalization_stage)(
const __global INPUT6_TYPE* subsequence_begins,
#endif
__global OUTPUT_TYPE* output,
#if PAGED_ATTENTION_SCORES_OUTPUT
__global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
const __global int* subsequence_offsets,
#endif
const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
const __global OUTPUT_TYPE* tmp_out,
Expand Down Expand Up @@ -500,3 +530,155 @@ KERNEL(pa_sdpa_finalization_stage)(
}

#endif

#ifdef SDPA_STAGE_2
#define MAX_PARTITIONS_NUM 128

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(pa_sdpa_scores_calculation)(
const __global INPUT3_TYPE* past_lens,
const __global INPUT6_TYPE* subsequence_begins,
__global OUTPUT1_TYPE* scores_output,
const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output,
const __global int* subsequence_offsets,
const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
const __global OUTPUT_TYPE* tmp_out,
const uint is_mixed_mode) {
const uint subsequence_idx = get_global_id(2);
const uint partition_global_idx = get_global_id(0);
const uint local_id = get_local_id(0);
const uint partition_idx = get_group_id(0);
const uint partition_size = get_local_size(0);
const uint max_seq_len = get_global_size(0);
const uint partitions_num = get_num_groups(0);
const uint sgid = get_sub_group_id();
const uint sgid_num = get_num_sub_groups();
const uint sglid = get_sub_group_local_id();

const int subsequence_begin = subsequence_begins[subsequence_idx];
const int subsequence_end = subsequence_begins[subsequence_idx + 1];
const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx];

const uint num_of_partitions = CEIL_DIV(seq_len, partition_size);

if (partition_idx >= num_of_partitions)
return;

__local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sums[HEADS_NUM];
__local SOFTMAX_ACCUMULATOR_TYPE slm_global_exp_sum[HEADS_NUM];

SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO;
if (seq_len <= partition_size) {
// If seq_len is less than the partition size, just reduce the results over the heads
for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
total_score += softmax_value;
}
} else if (seq_len <= partition_size * MAX_PARTITIONS_NUM) {
// Optimized version for longer prompts (up to partition_size * MAX_PARTITIONS_NUM, ~64K tokens)

// Depending on the previous kernel exp_sums and max_logits might have different structure:
// For ordinary 1st and 2nd token kernels, there is only a single entry per subsequence.
// However, for mixed mode execution, exp_sums and max_logits include information for all
// tokens of each subsequence, but only the last one is needed for score calculation.
const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;

for (uint head_idx = sgid; head_idx < HEADS_NUM; head_idx += sgid_num) {
SOFTMAX_ACCUMULATOR_TYPE max_logit[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];
SOFTMAX_ACCUMULATOR_TYPE exp_sum[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];

const uint exp_sums_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
for (int i = 0; i < partitions_num / SUBGROUP_SIZE; i++) {
max_logit[i] = max_logits[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
exp_sum[i] = exp_sums[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
}

const uint partitions_leftovers = partitions_num % SUBGROUP_SIZE;
if (partitions_leftovers != 0) {
const uint idx = partitions_num / SUBGROUP_SIZE;
max_logit[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
exp_sum[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
}

SOFTMAX_ACCUMULATOR_TYPE global_max_logit = max_logit[0];
for (uint i = 1; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit[i]);
}

global_max_logit = sub_group_reduce_max(global_max_logit);

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum[i] * native_exp(max_logit[i] - global_max_logit);
// slm_exp_sums[head_idx][i * SUBGROUP_SIZE + sglid] = adjusted_exp_sum;
if (i * SUBGROUP_SIZE + sglid == partition_idx)
slm_exp_sums[head_idx] = adjusted_exp_sum;
global_exp_sum += adjusted_exp_sum;
}

global_exp_sum = sub_group_reduce_add(global_exp_sum);

slm_global_exp_sum[head_idx] = global_exp_sum;
}

barrier(CLK_LOCAL_MEM_FENCE);

for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = slm_exp_sums[head_idx];
SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = slm_global_exp_sum[head_idx];

const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];

softmax_value = softmax_value * adjusted_exp_sum / global_exp_sum;
total_score += softmax_value;
}
} else {
// Non optimized fallback version
const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;
for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
SOFTMAX_ACCUMULATOR_TYPE global_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
const uint max_logits_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
const uint partition_offset = i * SUBGROUP_SIZE + sglid;
SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit);
}

global_max_logit = sub_group_reduce_max(global_max_logit);

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE partition_adjusted_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
const uint exp_sums_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
const uint partition_offset = i * SUBGROUP_SIZE + sglid;
SOFTMAX_ACCUMULATOR_TYPE exp_sum = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_base_offset + partition_offset];
SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit);
global_exp_sum += adjusted_exp_sum;

// Save and broadcast the adjusted exp_sum for the currently being processed partition
if (i == partition_idx / SUBGROUP_SIZE)
partition_adjusted_exp_sum = sub_group_broadcast(adjusted_exp_sum, partition_idx % SUBGROUP_SIZE);
}

global_exp_sum = sub_group_reduce_add(global_exp_sum);

const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];

softmax_value = softmax_value * partition_adjusted_exp_sum / global_exp_sum;
total_score += softmax_value;
}
}

const uint output_offset = subsequence_offsets[subsequence_idx];
if (partition_global_idx < seq_len) {
scores_output[output_offset + partition_global_idx] = total_score;
}
}

#undef MAX_PARTITIONS_NUM
#endif
Loading

0 comments on commit e2ac535

Please sign in to comment.