Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft #2

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 125 additions & 136 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -1292,185 +1292,174 @@ KERNEL(sdpa_opt_finalization_stage)(
const uint b0_idx = batch_idx / NUM_HEADS;
const uint b1_idx = batch_idx % NUM_HEADS;
const uint target_seq_idx = get_global_id(1);
const uint local_id = get_local_id(2);
const uint sgid = get_sub_group_id();
const uint sglid = get_sub_group_local_id();

const uint offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions);
__global SOFTMAX_ACCUMULATOR_TYPE* cur_exp_sums = exp_sums + offset;
__global SOFTMAX_ACCUMULATOR_TYPE* cur_max_logits = max_logits + offset;
__local SOFTMAX_ACCUMULATOR_TYPE tmp_slm[HEAD_SIZE / SUBGROUP_SIZE];

if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI_LOWER) {
/* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI_LOWER(8/16) = 32768/65536 tokens */
SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sums[exp_sums_offset];
max_logit[i] = max_logits[max_logit_offset];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]);
}
SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum_slm[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
max_logits_u_exp_sum_slm[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum_slm[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
local_max_logit = sub_group_reduce_max(local_max_logit);

// Update exp_sum with respect to the global maximum
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max);
local_exp_sum += exp_sum[i];
}
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum_slm[i] - local_max_logit);
max_logits_u_exp_sum_slm[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}

SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum);

for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) {
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
local_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);
partition_idx * (HEAD_SIZE) + local_id;
// (head_size_idx * SUBGROUP_SIZE + sglid);
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) *
TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) /
TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

output[out_offset] = TO_OUTPUT_TYPE(acc);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum_slm[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum);
} else if (num_of_partitions <= SUBGROUP_SIZE * REG_VERSION_MAX_VALUES_PER_WI) {
/* Registers kernel version, can handle up to SEQ_LEN_PARTITION_SIZE(256) * SUBGROUP_SIZE(16) * REG_VERSION_MAX_VALUES_PER_WI(24/48) = 98304/196608 tokens */
SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE max_logit[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE max_logits_u_exp_sum_slm[REG_VERSION_MAX_VALUES_PER_WI] = {SOFTMAX_ACCUMULATOR_VAL_MIN};
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sums[exp_sums_offset];
max_logit[i] = max_logits[max_logit_offset];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logit[i]);
}
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
max_logits_u_exp_sum_slm[i] = cur_max_logits[i];
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits_u_exp_sum_slm[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
local_max_logit = sub_group_reduce_max(local_max_logit);

// Update exp_sum with respect to the global maximum
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
if (partition_idx < num_of_partitions) {
exp_sum[i] = exp_sum[i] * native_exp(max_logit[i] - global_max);
local_exp_sum += exp_sum[i];
}
// SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = cur_exp_sums[i] * native_exp(max_logits_u_exp_sum_slm[i] - local_max_logit);
max_logits_u_exp_sum_slm[i] = exp_sum_new;
local_exp_sum += exp_sum_new;
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}

SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum);

for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) {
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
local_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);
partition_idx * (HEAD_SIZE) + local_id;
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) *
TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(exp_sum[partition_idx / SUBGROUP_SIZE], partition_idx % SUBGROUP_SIZE)) /
TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

output[out_offset] = TO_OUTPUT_TYPE(acc);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(max_logits_u_exp_sum_slm[partition_idx]);
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum);
} else {
/* Global memory kernel version, can handle any number of tokens, but could be very slow. */
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE local_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint iters_num = CEIL_DIV(num_of_partitions, SUBGROUP_SIZE);
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint max_logit_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;


if (partition_idx < num_of_partitions) {
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, max_logits[max_logit_offset]);
}
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
local_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(local_max_logit, cur_max_logits[i]);
}
local_max_logit = sub_group_reduce_max(local_max_logit);
if (sglid == 0) {
tmp_slm[sgid] = local_max_logit;
}
barrier(CLK_LOCAL_MEM_FENCE);

SOFTMAX_ACCUMULATOR_TYPE global_max = sub_group_reduce_max(local_max_logit);

// Calculate global sum
for (uint i = 0; i < iters_num; i++) {
const uint partition_idx = i * SUBGROUP_SIZE + sglid;
const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

if (partition_idx < num_of_partitions) {
local_exp_sum += exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max);
}
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_max_logit = tmp_slm[sglid];
}
local_max_logit = sub_group_reduce_max(local_max_logit);

SOFTMAX_ACCUMULATOR_TYPE global_sum = sub_group_reduce_add(local_exp_sum);
// Update exp_sum with respect to the global maximum
// SOFTMAX_ACCUMULATOR_TYPE exp_sum[REG_VERSION_MAX_VALUES_PER_WI_LOWER] = {SOFTMAX_ACCUMULATOR_VAL_ZERO};
SOFTMAX_ACCUMULATOR_TYPE local_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = local_id; i < num_of_partitions; i+= HEAD_SIZE) {
local_exp_sum += cur_exp_sums[i] * native_exp(cur_max_logits[i] - local_max_logit);
}
local_exp_sum = sub_group_reduce_add(local_exp_sum);
if (sglid == 0) {
tmp_slm[sgid] = local_exp_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
local_exp_sum = 0;
if (sglid < HEAD_SIZE / SUBGROUP_SIZE) {
local_exp_sum = tmp_slm[sglid];
}

for (uint head_size_idx = 0; head_size_idx < HEAD_SIZE / SUBGROUP_SIZE; head_size_idx++) {
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
local_exp_sum = sub_group_reduce_add(local_exp_sum);
SOFTMAX_ACCUMULATOR_TYPE acc = 0.0f;
for (uint partition_idx = 0; partition_idx < num_of_partitions; partition_idx++) {
const uint tmp_out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions * HEAD_SIZE) +
target_seq_idx * (num_of_partitions * HEAD_SIZE) +
partition_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

const uint exp_sums_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * num_of_partitions) +
b1_idx * (TARGET_SEQ_LEN * num_of_partitions) +
target_seq_idx * (num_of_partitions) +
partition_idx;
const uint max_logit_offset = exp_sums_offset;

SOFTMAX_ACCUMULATOR_TYPE new_exp_sum = exp_sums[exp_sums_offset] * native_exp(max_logits[max_logit_offset] - global_max);

partition_idx * (HEAD_SIZE) + local_id;
// (head_size_idx * SUBGROUP_SIZE + sglid);
OUTPUT_TYPE out_val = tmp_out[tmp_out_offset];
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(out_val) * new_exp_sum / TO_SOFTMAX_ACCUMULATOR_TYPE(global_sum);
}

const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
(head_size_idx * SUBGROUP_SIZE + sglid);

output[out_offset] = TO_OUTPUT_TYPE(acc);
acc += TO_SOFTMAX_ACCUMULATOR_TYPE(tmp_out[tmp_out_offset]) * TO_SOFTMAX_ACCUMULATOR_TYPE(cur_exp_sums[partition_idx] * native_exp(cur_max_logits[partition_idx] - local_max_logit));
}
const uint out_offset = b0_idx * (NUM_HEADS * TARGET_SEQ_LEN * HEAD_SIZE) +
b1_idx * (TARGET_SEQ_LEN * HEAD_SIZE) +
target_seq_idx * (HEAD_SIZE) +
local_id;

output[out_offset] = TO_OUTPUT_TYPE(acc) / TO_OUTPUT_TYPE(local_exp_sum);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
} else if (kernel_idx == KernelsTypes::FINALIZATION) {
dispatch_data.gws = { batch_size * heads_num,
target_seq_len,
subgroup_size };
dispatch_data.lws = { 1, 1, subgroup_size };
head_size };
dispatch_data.lws = { 1, 1, head_size };
}
}

Expand Down
Loading