
Commit

fix for PAGED_ATTENTION_SCORES_OUTPUT.
ceciliapeng2011 committed Jan 3, 2025
1 parent 36b4668 commit 620a51f
Showing 1 changed file with 41 additions and 38 deletions.
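What the commit changes, in outline: the PAGED_ATTENTION_SCORES_OUTPUT block used to run once after the per-row reduction loop and picked the "last row" with if (sglid == last_row_idx), comparing a lane id against a row index even though rows are strided across subgroups (for (uint m = sgid; ...; m += SUBGROUPS_PER_WG)). It now runs inside the loop, keyed on if (m == last_row_idx), with lane 0 doing the scalar stores. The stored per-partition statistics are also made self-consistent: max_logits receives the partition-local maximum qk_max_cur instead of the running maximum qk_max_new, exp_sums is rescaled by native_exp(qk_max_new - qk_max_cur) to match, softmax_results is normalized by exp_sum_new rather than copying raw qk_acc values, and the float locals in the update step become SOFTMAX_ACCUMULATOR_TYPE.

The rescaling relies on a standard rebasing identity. A minimal sketch in plain C (illustrative names; it assumes, as the elided loop body suggests, that exp_sum_new is accumulated as sum_i exp(x_i - qk_max_new)):

#include <math.h>

/* sum_i exp(x_i - m_cur) = exp(m_new - m_cur) * sum_i exp(x_i - m_new) */
float rebase_exp_sum(float exp_sum_vs_m_new, float m_new, float m_cur) {
    /* mirrors native_exp(qk_max_new - qk_max_cur) in the kernel */
    float correction_factor = expf(m_new - m_cur);
    return exp_sum_vs_m_new * correction_factor;
}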
79 changes: 41 additions & 38 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -1187,18 +1187,20 @@ KERNEL(sdpa_opt)(
uint aligned_width = ((SUBGROUPS_PER_WG + (SUBGROUP_SIZE-1)) & ~(SUBGROUP_SIZE-1));
for (uint m = sgid; m < seq_idx_end; m += SUBGROUPS_PER_WG) {
// rowmax
- SOFTMAX_ACCUMULATOR_TYPE max_val_prev = slm_max_val_prev[m];
- SOFTMAX_ACCUMULATOR_TYPE qk_max_new, qk_max_last = max_val_prev;
+ SOFTMAX_ACCUMULATOR_TYPE qk_max_new, qk_max_cur = SOFTMAX_ACCUMULATOR_VAL_MIN;
for (uint k = sglid; k < aligned_width; k += SUBGROUP_SIZE) {
if (k < SUBGROUPS_PER_WG) {
qk_max_new = slm_qk_max_vals[m][k];
} else {
qk_max_new = SOFTMAX_ACCUMULATOR_VAL_MIN;
}
- qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_reduce_max(qk_max_new), qk_max_last);
- qk_max_last = qk_max_new;
+ qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_reduce_max(qk_max_new), qk_max_cur);
+ qk_max_cur = qk_max_new;
}

+ SOFTMAX_ACCUMULATOR_TYPE max_val_prev = slm_max_val_prev[m];
+ qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_reduce_max(qk_max_cur), max_val_prev);
+
// softmax
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint k = sglid; k < partition_seq_len; k += SUBGROUP_SIZE) {
@@ -1208,11 +1210,43 @@ KERNEL(sdpa_opt)(
}
exp_sum_new = sub_group_reduce_add(exp_sum_new);

+ #if PAGED_ATTENTION_SCORES_OUTPUT
+ const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim];
+ const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
+
+ // PagedAttention is supposed to save only the last "row" of the QK matrix multiplication,
+ // so save SEQ_LEN_PARTITION_SIZE elements for each partition
+ if (subsequence_end_pos == block_end_pos) {
+ const uint last_row_idx = block_end_pos - block_start_pos - 1;
+ if (m == last_row_idx) {
+ const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE;
+
+ SOFTMAX_ACCUMULATOR_TYPE correction_factor = native_exp(qk_max_new - qk_max_cur);
+
+ if (sglid == 0) {
+ const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE;
+ const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num +
+ num_heads_dim * max_partitions_num +
+ partition_idx;
+ exp_sums[exp_sums_output_offset] = exp_sum_new * correction_factor;
+ max_logits[exp_sums_output_offset] = qk_max_cur;
+ }
+
+ const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
+ num_heads_dim * aligned_max_context_len +
+ partition_idx * SEQ_LEN_PARTITION_SIZE;
+ for (uint i = sglid; i < partition_seq_len; i += SUBGROUP_SIZE) {
+ softmax_results[output_offset + i] = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[m][i]) / exp_sum_new;
+ }
+ }
+ }
+ #endif
+
// update
if (sglid == 0) {
- float pre_exp_sum = slm_exp_sum_prev[m];
- float correction_factor = native_exp(max_val_prev - qk_max_new);
- float pre_exp_sum_fixed = pre_exp_sum * correction_factor;
+ SOFTMAX_ACCUMULATOR_TYPE pre_exp_sum = slm_exp_sum_prev[m];
+ SOFTMAX_ACCUMULATOR_TYPE correction_factor = native_exp(max_val_prev - qk_max_new);
+ SOFTMAX_ACCUMULATOR_TYPE pre_exp_sum_fixed = pre_exp_sum * correction_factor;
exp_sum_new += pre_exp_sum_fixed;

slm_update_factor[m] = correction_factor;
@@ -1221,37 +1255,6 @@
}
}

- #if PAGED_ATTENTION_SCORES_OUTPUT
- const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim];
- const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
-
- // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
- // so save SEQ_LEN_PARTITION_SIZE elements for each partition
- if (subsequence_end_pos == block_end_pos) {
- const uint last_row_idx = block_end_pos - block_start_pos - 1;
- if (sglid == last_row_idx) {
- const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE;
-
- if (sgid == 0) {
- const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE;
- const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num +
- num_heads_dim * max_partitions_num +
- partition_idx;
- exp_sums[exp_sums_output_offset] = exp_sum_new;
- max_logits[exp_sums_output_offset] = qk_max_new;
- }
-
- const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
- num_heads_dim * aligned_max_context_len +
- partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE;
- for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
- softmax_results[output_offset + i] = qk_acc[i];
- }
-
- }
- }
- #endif
-
barrier(CLK_LOCAL_MEM_FENCE);
}

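A note on the first hunk (@@ -1187): each row m owns SUBGROUPS_PER_WG partial maxima in shared local memory, and the lanes of one subgroup sweep them with a stride of SUBGROUP_SIZE. The sweep width is rounded up to a multiple of SUBGROUP_SIZE so every lane executes the same iteration count, and out-of-range slots are padded with SOFTMAX_ACCUMULATOR_VAL_MIN, the identity element for max. A serial plain-C sketch of the same idea (the SUBGROUPS_PER_WG and SUBGROUP_SIZE values are assumptions for illustration; the loop stands in for sub_group_reduce_max):

#include <float.h>

#define SUBGROUPS_PER_WG 8   /* assumed for illustration */
#define SUBGROUP_SIZE    16  /* assumed for illustration; must be a power of two */

float rowmax(const float partial_max[SUBGROUPS_PER_WG]) {
    /* round the sweep width up to a multiple of SUBGROUP_SIZE */
    unsigned aligned_width =
        (SUBGROUPS_PER_WG + (SUBGROUP_SIZE - 1)) & ~(SUBGROUP_SIZE - 1u);

    float m = -FLT_MAX; /* plays the role of SOFTMAX_ACCUMULATOR_VAL_MIN */
    for (unsigned k = 0; k < aligned_width; ++k) {
        /* padded slots contribute the identity, so they never win the max */
        float v = (k < SUBGROUPS_PER_WG) ? partial_max[k] : -FLT_MAX;
        if (v > m) m = v;
    }
    return m;
}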

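Why per-partition self-consistency matters: a later reduction pass recombines the partitions, and it can only do so if each stored exp_sums entry was accumulated against the max recorded in the matching max_logits entry. A hypothetical host-side illustration in plain C (not the actual reduction kernel; the function and parameter names are invented):

#include <math.h>
#include <stddef.h>

/* Combine per-partition pairs (s_p = sum_i exp(x_i - m_p), m_p) into the
   global softmax denominator: total = sum_p s_p * exp(m_p - M), M = max_p m_p. */
float combine_partitions(const float *exp_sums, const float *max_logits,
                         size_t num_partitions) {
    float M = -INFINITY;
    for (size_t p = 0; p < num_partitions; ++p)
        M = fmaxf(M, max_logits[p]);

    float total = 0.0f;
    for (size_t p = 0; p < num_partitions; ++p)
        total += exp_sums[p] * expf(max_logits[p] - M); /* rebase each partition to M */

    return total; /* softmax(x_i) = exp(x_i - M) / total */
}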