diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index 948bd3c0f1a305..168e885f6f26a8 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -1292,185 +1292,174 @@ KERNEL(sdpa_opt_finalization_stage)( const uint b0_idx = batch_idx / NUM_HEADS; const uint b1_idx = batch_idx % NUM_HEADS; const uint target_seq_idx = get_global_id(1); + const uint local_id = get_local_id(2); + const uint sgid = get_sub_group_id(); const uint sglid = get_sub_group_local_id(); + const uint offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + + b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + + target_seq_idx * (num_of_partitions); + __global SOFTMAX_ACCUMULATOR_TYPE* cur_exp_sums = exp_sums + offset; + __global SOFTMAX_ACCUMULATOR_TYPE* cur_max_logits = max_logits + offset; + __local SOFTMAX_ACCUMULATOR_TYPE tmp_slm[HEAD_SIZE / SUBGROUP_SIZE]; + if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI_LOWER) { /* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI_LOWER(8/16) = 32768/65536 tokens */ - SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; - SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; - SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; - SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; - const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE); - for (uint i = 0; i < iters_num; i++) { - const uint partition_idx = i * SUBGROUP_SIZE + sglid; - const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + - b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + - target_seq_idx * (num_of_partitions) + - partition_idx; - const uint max_logit_offset = exp_sums_offset; - - if (partition_idx < num_of_partitions) { - exp_sum[i] = exp_sums[exp_sums_offset]; - max_logit[i] = max_logits[max_logit_offset]; - local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]); - } + SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum_slm[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; + SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; + for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) { + max_logits_u_exp_sum_slm[i] = cur_max_logits[i]; + local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum_slm[i]); } + local_max_logit = sub_group_reduce_max(local_max_logit); + if (sglid == 0) { + tmp_slm[sgid] = local_max_logit; + } + barrier(CLK_LOCAL_MEM_FENCE); - SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit); + if (sglid < HEAD_SIZE / SUBGROUP_SIZE) { + local_max_logit = tmp_slm[sglid]; + } + local_max_logit = sub_group_reduce_max(local_max_logit); // Update exp_sum with respect to the global maximum - for (uint i = 0; i < iters_num; i++) { - const uint partition_idx = i * SUBGROUP_SIZE + sglid; - if (partition_idx < num_of_partitions) { - exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max); - local_exp_sum += exp_sum[i]; - } + SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) { + SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum_slm[i] - local_max_logit); + max_logits_u_exp_sum_slm[i] = exp_sum_new; + local_exp_sum += exp_sum_new; + } + local_exp_sum = sub_group_reduce_add(local_exp_sum); + if (sglid == 0) { + tmp_slm[sgid] = local_exp_sum; + } + barrier(CLK_LOCAL_MEM_FENCE); + local_exp_sum = 0; + if (sglid < HEAD_SIZE / SUBGROUP_SIZE) { + local_exp_sum = tmp_slm[sglid]; } - SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum); - - for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) { - SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; - for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { + local_exp_sum = sub_group_reduce_add(local_exp_sum); + SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; + for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + target_seq_idx * (num_of_partitions * HEAD_SIZE) + - partition_idx * (HEAD_SIZE) + - (head_size_idx * SUBGROUP_SIZE + sglid); + partition_idx * (HEAD_SIZE) + local_id; + // (head_size_idx * SUBGROUP_SIZE + sglid); OUTPUT_TYPE out_val = tmp_out[tmp_out_offset]; - acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * - TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) / - TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum); - } - const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + - b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + - target_seq_idx * (HEAD_SIZE) + - (head_size_idx * SUBGROUP_SIZE + sglid); - - output[out_offset] = TO_OUTPUT_TYPE(acc); + acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum_slm[partition_idx]); } + const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + + target_seq_idx * (HEAD_SIZE) + + local_id; + + output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum); } else if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI) { /* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI(24/48) = 98304/196608 tokens */ - SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; - SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; - SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum_slm[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN}; SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; - - const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE); - for (uint i = 0; i < iters_num; i++) { - const uint partition_idx = i * SUBGROUP_SIZE + sglid; - const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + - b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + - target_seq_idx * (num_of_partitions) + - partition_idx; - const uint max_logit_offset = exp_sums_offset; - - if (partition_idx < num_of_partitions) { - exp_sum[i] = exp_sums[exp_sums_offset]; - max_logit[i] = max_logits[max_logit_offset]; - local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]); - } + for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) { + max_logits_u_exp_sum_slm[i] = cur_max_logits[i]; + local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum_slm[i]); + } + local_max_logit = sub_group_reduce_max(local_max_logit); + if (sglid == 0) { + tmp_slm[sgid] = local_max_logit; } + barrier(CLK_LOCAL_MEM_FENCE); - SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit); + if (sglid < HEAD_SIZE / SUBGROUP_SIZE) { + local_max_logit = tmp_slm[sglid]; + } + local_max_logit = sub_group_reduce_max(local_max_logit); // Update exp_sum with respect to the global maximum - for (uint i = 0; i < iters_num; i++) { - const uint partition_idx = i * SUBGROUP_SIZE + sglid; - if (partition_idx < num_of_partitions) { - exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max); - local_exp_sum += exp_sum[i]; - } + // SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; + SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) { + SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum_slm[i] - local_max_logit); + max_logits_u_exp_sum_slm[i] = exp_sum_new; + local_exp_sum += exp_sum_new; + } + local_exp_sum = sub_group_reduce_add(local_exp_sum); + if (sglid == 0) { + tmp_slm[sgid] = local_exp_sum; + } + barrier(CLK_LOCAL_MEM_FENCE); + local_exp_sum = 0; + if (sglid < HEAD_SIZE / SUBGROUP_SIZE) { + local_exp_sum = tmp_slm[sglid]; } - SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum); - - for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) { - SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; - for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { + local_exp_sum = sub_group_reduce_add(local_exp_sum); + SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; + for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + target_seq_idx * (num_of_partitions * HEAD_SIZE) + - partition_idx * (HEAD_SIZE) + - (head_size_idx * SUBGROUP_SIZE + sglid); + partition_idx * (HEAD_SIZE) + local_id; OUTPUT_TYPE out_val = tmp_out[tmp_out_offset]; - acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * - TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) / - TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum); - } - const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + - b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + - target_seq_idx * (HEAD_SIZE) + - (head_size_idx * SUBGROUP_SIZE + sglid); - - output[out_offset] = TO_OUTPUT_TYPE(acc); + acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum_slm[partition_idx]); } + const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + + target_seq_idx * (HEAD_SIZE) + + local_id; + + output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum); } else { /* Global memory kernel version, can handle any number of tokens, but could be very slow. */ - SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; - - const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE); - for (uint i = 0; i < iters_num; i++) { - const uint partition_idx = i * SUBGROUP_SIZE + sglid; - const uint max_logit_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + - b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + - target_seq_idx * (num_of_partitions) + - partition_idx; - - - if (partition_idx < num_of_partitions) { - local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits[max_logit_offset]); - } + for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) { + local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, cur_max_logits[i]); } + local_max_logit = sub_group_reduce_max(local_max_logit); + if (sglid == 0) { + tmp_slm[sgid] = local_max_logit; + } + barrier(CLK_LOCAL_MEM_FENCE); - SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit); - - // Calculate global sum - for (uint i = 0; i < iters_num; i++) { - const uint partition_idx = i * SUBGROUP_SIZE + sglid; - const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + - b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + - target_seq_idx * (num_of_partitions) + - partition_idx; - const uint max_logit_offset = exp_sums_offset; - - if (partition_idx < num_of_partitions) { - local_exp_sum += exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max); - } + if (sglid < HEAD_SIZE / SUBGROUP_SIZE) { + local_max_logit = tmp_slm[sglid]; } + local_max_logit = sub_group_reduce_max(local_max_logit); - SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum); + // Update exp_sum with respect to the global maximum + // SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO}; + SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) { + local_exp_sum += cur_exp_sums[i] * native_exp(cur_max_logits[i] - local_max_logit); + } + local_exp_sum = sub_group_reduce_add(local_exp_sum); + if (sglid == 0) { + tmp_slm[sgid] = local_exp_sum; + } + barrier(CLK_LOCAL_MEM_FENCE); + local_exp_sum = 0; + if (sglid < HEAD_SIZE / SUBGROUP_SIZE) { + local_exp_sum = tmp_slm[sglid]; + } - for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) { - SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; - for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { + local_exp_sum = sub_group_reduce_add(local_exp_sum); + SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f; + for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) { const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) + target_seq_idx * (num_of_partitions * HEAD_SIZE) + - partition_idx * (HEAD_SIZE) + - (head_size_idx * SUBGROUP_SIZE + sglid); - - const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) + - b1_idx * (TARGET_SEQ_LEN * num_of_partitions) + - target_seq_idx * (num_of_partitions) + - partition_idx; - const uint max_logit_offset = exp_sums_offset; - - SOFTMAX_ACCUMULATOR_TYPE new_exp_sum = exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max); - + partition_idx * (HEAD_SIZE) + local_id; + // (head_size_idx * SUBGROUP_SIZE + sglid); OUTPUT_TYPE out_val = tmp_out[tmp_out_offset]; - acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * new_exp_sum / TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum); - } - - const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + - b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + - target_seq_idx * (HEAD_SIZE) + - (head_size_idx * SUBGROUP_SIZE + sglid); - - output[out_offset] = TO_OUTPUT_TYPE(acc); + acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(cur_exp_sums[partition_idx] * native_exp(cur_max_logits[partition_idx] - local_max_logit)); } + const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) + + b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) + + target_seq_idx * (HEAD_SIZE) + + local_id; + + output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum); } } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp index 2f0174d0a45912..cd30206293f541 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp @@ -244,8 +244,8 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k } else if (kernel_idx == KernelsTypes::FINALIZATION) { dispatch_data.gws = { batch_size * heads_num, target_seq_len, - subgroup_size }; - dispatch_data.lws = { 1, 1, subgroup_size }; + head_size }; + dispatch_data.lws = { 1, 1, head_size }; } }