diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index cddafe62623d9e..140fb0b9263748 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -172,7 +172,7 @@ KERNEL(sdpa_opt)( #endif #if SUBGROUPS_PER_WG > SUBGROUP_SIZE - #error "sdpa_opt.cl: Number of subgroups per work group should be less than subgroup_size + #error "sdpa_opt.cl: Number of subgroups per work group should be no more than subgroup_size" #endif const uint sgid = get_sub_group_id(); @@ -876,29 +876,30 @@ KERNEL(sdpa_opt)( __local INPUT0_TYPE slm_query[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE]; // SLM buffer for intermediate QK results - __local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE]; + __local OUTPUT_TYPE slm_qk_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SEQ_LEN_PARTITION_SIZE]; // SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs - __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE]; - __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE]; + __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SUBGROUPS_PER_WG]; // SLM buffers for SoftMax recalculation for current iteration based on the previous results - __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_cur[TARGET_SEQ_LEN_BLOCK_SIZE]; - __local SOFTMAX_ACCUMULATOR_TYPE slm_max_val_cur[TARGET_SEQ_LEN_BLOCK_SIZE]; __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_prev[TARGET_SEQ_LEN_BLOCK_SIZE]; __local SOFTMAX_ACCUMULATOR_TYPE slm_max_val_prev[TARGET_SEQ_LEN_BLOCK_SIZE]; + __local SOFTMAX_ACCUMULATOR_TYPE slm_update_factor[TARGET_SEQ_LEN_BLOCK_SIZE]; +#if IS_PAGED_ATTENTION + const uint block_start_pos = blocked_indexes_start[target_seq_dim]; + const uint block_end_pos = blocked_indexes_end[target_seq_dim]; + const uint seq_idx_end = block_end_pos - block_start_pos; +#else + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#endif { // Load Q input to SLM and transpose it #if IS_PAGED_ATTENTION - const uint block_start_pos = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos = blocked_indexes_end[target_seq_dim]; uint query_offset = INPUT0_OFFSET + block_start_pos * (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) + num_heads_dim * HEAD_SIZE + head_size_idx; const uint query_pitch = (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM); - - const uint cur_target_seq_len_size = block_end_pos - block_start_pos; #else #ifdef INPUT0_DIMS_ORDER uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx, (head_size_idx)); @@ -908,7 +909,6 @@ KERNEL(sdpa_opt)( uint query_offset = INPUT0_GET_INDEX(b0_idx, b1_idx, target_seq_idx, (head_size_idx)); const uint query_pitch = HEAD_SIZE; #endif - const uint cur_target_seq_len_size = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); #endif uint query_local_offset = head_size_idx * TARGET_SEQ_LEN_BLOCK_SIZE; @@ -922,9 +922,9 @@ KERNEL(sdpa_opt)( const INPUT0_TYPE scale_val = INPUT0_VAL_ONE; #endif - if (cur_target_seq_len_size != TARGET_SEQ_LEN_BLOCK_SIZE) { + if (seq_idx_end != TARGET_SEQ_LEN_BLOCK_SIZE) { if (sgid * SUBGROUP_SIZE < HEAD_SIZE) { - for (uint seq_idx = 0; seq_idx < cur_target_seq_len_size; seq_idx++) { + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { INPUT0_TYPE val = BLOCK_READN(INPUT0_TYPE, 1, query_input, query_offset); slm_query[query_local_offset] = val * scale_val; @@ -993,8 +993,6 @@ KERNEL(sdpa_opt)( __attribute__((opencl_unroll_hint(1))) for (uint start_partition_idx = 0; start_partition_idx < SOURCE_SEQ_LEN; start_partition_idx += SEQ_LEN_PARTITION_SIZE) { - SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN; - const uint seq_len = start_partition_idx + sgid * SUBGROUP_SIZE; const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE); @@ -1034,7 +1032,7 @@ KERNEL(sdpa_opt)( b0_idx, b1_idx, #if IS_PAGED_ATTENTION - blocked_indexes_start[target_seq_dim] - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid, + block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid, #else target_seq_idx + sglid, #endif @@ -1157,6 +1155,7 @@ KERNEL(sdpa_opt)( } { + SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN; unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { #if !APPLY_SCALES_TO_QUERY #if HAS_SCALE_INPUT @@ -1175,93 +1174,86 @@ KERNEL(sdpa_opt)( qk_acc[i] = INPUT0_MIN_FUNC(INPUT0_MAX_FUNC(qk_acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX); qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i])); + slm_qk_vals[sglid][sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i]; } - } - - { - slm_qk_max_vals[sgid * SUBGROUP_SIZE + sglid] = qk_max; - qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN; + slm_qk_max_vals[sglid][sgid] = qk_max; } barrier(CLK_LOCAL_MEM_FENCE); { // SoftMax calculation - SOFTMAX_ACCUMULATOR_TYPE qk_max_new = SOFTMAX_ACCUMULATOR_VAL_MIN; - - for (uint i = 0; i < SUBGROUPS_PER_WG; i++) { - SOFTMAX_ACCUMULATOR_TYPE qk_max_val = slm_qk_max_vals[i * SUBGROUP_SIZE + sglid]; - qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max_new, qk_max_val); - } - - if (sgid == 0) { - slm_max_val_cur[sglid] = qk_max_new; - } - - SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO; - - for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { - qk_acc[i] = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]) - qk_max_new); - exp_sum_new += qk_acc[i]; - } - - { - slm_exp_sum_vals[sgid * SUBGROUP_SIZE + sglid] = exp_sum_new; - } - - exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO; - - barrier(CLK_LOCAL_MEM_FENCE); - - for (uint i = 0; i < SUBGROUPS_PER_WG; i++) { - SOFTMAX_ACCUMULATOR_TYPE exp_sum = slm_exp_sum_vals[i * SUBGROUP_SIZE + sglid]; - exp_sum_new += exp_sum; - } + // each sg will compute a whole row of query + uint aligned_width = ((SUBGROUPS_PER_WG + (SUBGROUP_SIZE-1)) & ~(SUBGROUP_SIZE-1)); + for (uint m = sgid; m < seq_idx_end; m += SUBGROUPS_PER_WG) { + // rowmax + SOFTMAX_ACCUMULATOR_TYPE qk_max_new, qk_max_cur = SOFTMAX_ACCUMULATOR_VAL_MIN; + for (uint k = sglid; k < aligned_width; k += SUBGROUP_SIZE) { + if (k < SUBGROUPS_PER_WG) { + qk_max_new = slm_qk_max_vals[m][k]; + } else { + qk_max_new = SOFTMAX_ACCUMULATOR_VAL_MIN; + } + qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_reduce_max(qk_max_new), qk_max_cur); + qk_max_cur = qk_max_new; + } - for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { - qk_acc[i] = qk_acc[i] / exp_sum_new; - } + SOFTMAX_ACCUMULATOR_TYPE max_val_prev = slm_max_val_prev[m]; + qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_reduce_max(qk_max_cur), max_val_prev); - if (sgid == 0) { - slm_exp_sum_cur[sglid] = exp_sum_new; - } - - for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { - slm_qk_vals[sglid * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i]; - } + // softmax + SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO; + for (uint k = sglid; k < partition_seq_len; k += SUBGROUP_SIZE) { + SOFTMAX_ACCUMULATOR_TYPE a = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[m][k]) - qk_max_new); + slm_qk_vals[m][k] = TO_OUTPUT_TYPE(a); + exp_sum_new += a; + } + exp_sum_new = sub_group_reduce_add(exp_sum_new); #if PAGED_ATTENTION_SCORES_OUTPUT - const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim]; - const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1]; - const uint block_start_pos = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos = blocked_indexes_end[target_seq_dim]; - - // PagedAttention is supposed to save only last "row" of the QK matrix multiplication, - // so save SEQ_LEN_PARTITION_SIZE elements for each partition - if (subsequence_end_pos == block_end_pos) { - const uint last_row_idx = block_end_pos - block_start_pos - 1; - if (sglid == last_row_idx) { - const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE; - - if (sgid == 0) { - const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE; - const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num + - num_heads_dim * max_partitions_num + - partition_idx; - exp_sums[exp_sums_output_offset] = exp_sum_new; - max_logits[exp_sums_output_offset] = qk_max_new; - } + const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim]; + const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1]; + + // PagedAttention is supposed to save only last "row" of the QK matrix multiplication, + // so save SEQ_LEN_PARTITION_SIZE elements for each partition + if (subsequence_end_pos == block_end_pos) { + const uint last_row_idx = block_end_pos - block_start_pos - 1; + if (m == last_row_idx) { + const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE; + + SOFTMAX_ACCUMULATOR_TYPE correction_factor = native_exp(qk_max_new - qk_max_cur); + + if (sglid == 0) { + const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE; + const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num + + num_heads_dim * max_partitions_num + + partition_idx; + exp_sums[exp_sums_output_offset] = exp_sum_new * correction_factor; + max_logits[exp_sums_output_offset] = qk_max_cur; + } - const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len + - num_heads_dim * aligned_max_context_len + - partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE; - for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) { - softmax_results[output_offset + i] = qk_acc[i]; + const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len + + num_heads_dim * aligned_max_context_len + + partition_idx * SEQ_LEN_PARTITION_SIZE; + for (uint i = sglid; i < partition_seq_len; i += SUBGROUP_SIZE) { + softmax_results[output_offset + i] = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[m][i]) / exp_sum_new; + } } + } +#endif + + // update + if (sglid == 0) { + SOFTMAX_ACCUMULATOR_TYPE pre_exp_sum = slm_exp_sum_prev[m]; + SOFTMAX_ACCUMULATOR_TYPE correction_factor = native_exp(max_val_prev - qk_max_new); + SOFTMAX_ACCUMULATOR_TYPE pre_exp_sum_fixed = pre_exp_sum * correction_factor; + exp_sum_new += pre_exp_sum_fixed; + slm_update_factor[m] = correction_factor; + slm_max_val_prev[m] = qk_max_new; + slm_exp_sum_prev[m] = exp_sum_new; } } -#endif barrier(CLK_LOCAL_MEM_FENCE); } @@ -1311,7 +1303,7 @@ KERNEL(sdpa_opt)( MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val; unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { - qk_val[seq_idx] = slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len + sglid]; + qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len + sglid]; } #if IS_KV_COMPRESSED @@ -1387,7 +1379,7 @@ KERNEL(sdpa_opt)( MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val; unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { - qk_val[seq_idx] = slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len * SUBGROUP_SIZE + sglid]; + qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len * SUBGROUP_SIZE + sglid]; } unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) { @@ -1417,11 +1409,9 @@ KERNEL(sdpa_opt)( // QK*V leftovers processing const uint seq_len_leftovers_start = ((seq_len_end / SUBGROUP_SIZE) * SUBGROUP_SIZE); if (seq_len_leftovers_start != seq_len_end) { - uint qk_offset = min(seq_len_leftovers_start + sglid, seq_len_end - 1); MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val; unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { - qk_val[seq_idx] = slm_qk_vals[qk_offset]; - qk_offset += SEQ_LEN_PARTITION_SIZE; + qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len_leftovers_start+sglid]; } #if IS_PAGED_ATTENTION #ifdef BROADCAST_GROUP_SIZE @@ -1484,41 +1474,17 @@ KERNEL(sdpa_opt)( } + // protect slm_qk_vals as it is read in w*v stage and write in next round q*k stage. + barrier(CLK_LOCAL_MEM_FENCE); + { // Rescale acc_output_res values and save current iter results to global accumulator - SOFTMAX_ACCUMULATOR_TYPE exp_sum_prev = slm_exp_sum_prev[sglid]; - SOFTMAX_ACCUMULATOR_TYPE exp_sum_cur = slm_exp_sum_cur[sglid]; - SOFTMAX_ACCUMULATOR_TYPE max_val_prev = slm_max_val_prev[sglid]; - SOFTMAX_ACCUMULATOR_TYPE max_val_cur = slm_max_val_cur[sglid]; - - barrier(CLK_LOCAL_MEM_FENCE); - -#if IS_PAGED_ATTENTION - const uint block_start_pos_new = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos_new = blocked_indexes_end[target_seq_dim]; - const uint seq_idx_end = block_end_pos_new - block_start_pos_new; -#else - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#endif - for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { - SOFTMAX_ACCUMULATOR_TYPE total_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_broadcast(max_val_prev, seq_idx), sub_group_broadcast(max_val_cur, seq_idx)); - SOFTMAX_ACCUMULATOR_TYPE updated_exp_sum_prev = sub_group_broadcast(exp_sum_prev, seq_idx) * native_exp(sub_group_broadcast(max_val_prev, seq_idx) - total_max); - SOFTMAX_ACCUMULATOR_TYPE updated_exp_sum_cur = sub_group_broadcast(exp_sum_cur, seq_idx) * native_exp(sub_group_broadcast(max_val_cur, seq_idx) - total_max); - SOFTMAX_ACCUMULATOR_TYPE updated_total_exp_sum = updated_exp_sum_prev + updated_exp_sum_cur; - if (start_partition_idx > 0) { - OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * updated_exp_sum_prev / updated_total_exp_sum;; - acc_output_res[seq_idx] *= updated_exp_sum_cur / updated_total_exp_sum; + OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * slm_update_factor[seq_idx]; acc_output_res[seq_idx] += updated_prev_res; } - output_acc[seq_idx] = acc_output_res[seq_idx]; - - if (sgid == 0 && sglid == 0) { - slm_exp_sum_prev[seq_idx] = updated_total_exp_sum; - slm_max_val_prev[seq_idx] = total_max; - } } } } @@ -1528,7 +1494,7 @@ KERNEL(sdpa_opt)( if (sgid >= (SUBGROUPS_PER_WG / SG_SCALE_FACTOR)) { unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { - slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + (uint)get_local_id(2)] = output_acc[seq_idx]; + slm_qk_vals[seq_idx][(uint)get_local_id(2)] = output_acc[seq_idx]; } } @@ -1537,34 +1503,27 @@ KERNEL(sdpa_opt)( if (sgid < (SUBGROUPS_PER_WG / SG_SCALE_FACTOR)) { unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { unroll_for (uint i = 1; i < SG_SCALE_FACTOR; i++) { - output_acc[seq_idx] += slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + (i * HEAD_SIZE) + head_size_idx]; + output_acc[seq_idx] += slm_qk_vals[seq_idx][(i * HEAD_SIZE) + head_size_idx]; } } #if IS_PAGED_ATTENTION - const uint block_start_pos_new = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos_new = blocked_indexes_end[target_seq_dim]; - - uint output_offset = block_start_pos_new * HEAD_SIZE * NUM_HEADS + num_heads_dim * HEAD_SIZE + sgid * SUBGROUP_SIZE; + uint output_offset = block_start_pos * HEAD_SIZE * NUM_HEADS + num_heads_dim * HEAD_SIZE + sgid * SUBGROUP_SIZE; const uint output_pitch = HEAD_SIZE * NUM_HEADS; #else uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx, sgid * SUBGROUP_SIZE); const uint output_pitch = HEAD_SIZE; #endif -#if IS_PAGED_ATTENTION - if (block_start_pos_new + TARGET_SEQ_LEN_BLOCK_SIZE != block_end_pos_new) { - const uint seq_idx_end = block_end_pos_new - block_start_pos_new; -#else - if (get_global_id(1) == get_global_size(1) - 1) { - const uint seq_idx_end = min((uint)TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#endif + if (TARGET_SEQ_LEN_BLOCK_SIZE > seq_idx_end) { for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { + output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx]; OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]); output_offset += output_pitch; } } else { unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) { + output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx]; OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]); output_offset += output_pitch; }