Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
zaixing-wang committed Jan 2, 2025
1 parent 2e25c87 commit 0e144a8
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 148 deletions.
270 changes: 124 additions & 146 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -1292,186 +1292,164 @@ KERNEL(sdpa_opt_finalization_stage)(
const uint b0_idx = batch_idx / NUM_HEADS;
const uint b1_idx = batch_idx % NUM_HEADS;
const uint target_seq_idx = get_global_id(1);
const uint local_id = get_local_id(2);
const uint sgid = get_sub_group_id();
const uint sglid = get_sub_group_local_id();

const uint offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions);
__global SOFTMAX_ACCUMULATOR_TYPE* cur_exp_sums = exp_sums + offset;
__global SOFTMAX_ACCUMULATOR_TYPE* cur_max_logits = max_logits + offset;
__local SOFTMAX_ACCUMULATOR_TYPE tmp_slm[HEAD_SIZE / SUBGROUP_SIZE];
if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI_LOWER) {
/* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI_LOWER(8/16) = 32768/65536 tokens */
SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum_slm[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sums[exp_sums_offset];
max_logit[i] = max_logits[max_logit_offset];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]);
}
// const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
max_logits_u_exp_sum_slm[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum_slm[i]);
}

SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
local_max_logit = sub_group_reduce_max(local_max_logit);
// Update exp_sum with respect to the global maximum
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max);
local_exp_sum += exp_sum[i];
}
// SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum_slm[i] - local_max_logit);
max_logits_u_exp_sum_slm[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}

SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum);

for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) {
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);
partition_idx * (HEAD_SIZE) + local_id;
// (head_size_idx * SUBGROUP_SIZE + sglid);
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) *
TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) /
TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

output[out_offset] = TO_OUTPUT_TYPE(acc);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum_slm[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;
output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum);
} else if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI) {
/* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI(24/48) = 98304/196608 tokens */
SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum_slm[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sums[exp_sums_offset];
max_logit[i] = max_logits[max_logit_offset];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]);
}
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
max_logits_u_exp_sum_slm[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum_slm[i]);
}

SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
local_max_logit = sub_group_reduce_max(local_max_logit);
// Update exp_sum with respect to the global maximum
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max);
local_exp_sum += exp_sum[i];
}
// SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum_slm[i] - local_max_logit);
max_logits_u_exp_sum_slm[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}

SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum);

for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) {
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);
partition_idx * (HEAD_SIZE) + local_id;
// (head_size_idx * SUBGROUP_SIZE + sglid);
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) *
TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) /
TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

output[out_offset] = TO_OUTPUT_TYPE(acc);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum_slm[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;
output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum);
} else {
/* Global memory kernel version, can handle any number of tokens, but could be very slow. */
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint max_logit_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;


if (partition_idx < num_of_partitions) {
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits[max_logit_offset]);
}
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, cur_max_logits[i]);
}

SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

// Calculate global sum
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

if (partition_idx < num_of_partitions) {
local_exp_sum += exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}

SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum);

for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) {
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
barrier(CLK_LOCAL_MEM_FENCE);
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
local_max_logit = sub_group_reduce_max(local_max_logit);
// Update exp_sum with respect to the global maximum
// SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
local_exp_sum += cur_exp_sums[i] * native_exp(cur_max_logits[i] - local_max_logit);
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

SOFTMAX_ACCUMULATOR_TYPE new_exp_sum = exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max);

partition_idx * (HEAD_SIZE) + local_id;
// (head_size_idx * SUBGROUP_SIZE + sglid);
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * new_exp_sum / TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}

const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

output[out_offset] = TO_OUTPUT_TYPE(acc);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(cur_exp_sums[partition_idx] * native_exp(cur_max_logits[partition_idx] - local_max_logit));
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;
output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum);
}
}


#endif
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
} else if (kernel_idx == KernelsTypes::FINALIZATION) {
dispatch_data.gws = { batch_size * heads_num,
target_seq_len,
subgroup_size };
dispatch_data.lws = { 1, 1, subgroup_size };
head_size };
dispatch_data.lws = { 1, 1, head_size };
}
}

Expand Down

0 comments on commit 0e144a8

Please sign in to comment.