Commit

Fix GPU names

vshampor committed Dec 31, 2024
1 parent c4a8183 commit 6a4a852
Showing 8 changed files with 22 additions and 26 deletions.
@@ -40,7 +40,7 @@ struct paged_attention : public primitive_base<paged_attention> {
         ob << heads_num;
         ob << kv_heads_num;
         ob << has_alibi;
-        ob << has_rotation_coefficients;
+        ob << has_rotated_blocks;
     }

@@ -49,14 +49,14 @@ struct paged_attention : public primitive_base<paged_attention> {
     void load(BinaryInputBuffer& ib) override {
         ib >> heads_num;
         ib >> kv_heads_num;
         ib >> has_alibi;
-        ib >> has_rotation_coefficients;
+        ib >> has_rotated_blocks;
     }

     optional_value<float> scale_val{};
     size_t head_size = 0;
     size_t heads_num = 0;
     size_t kv_heads_num = 0;
     bool has_alibi = false;
-    bool has_rotation_coefficients = false;
+    bool has_rotated_blocks = false;
 };
 } // namespace cldnn
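
Note on the hunks above: save() and load() must stay symmetric, writing and reading the renamed flag at the same position in the binary stream. A minimal, self-contained sketch of that round-trip property, using toy buffers in place of cldnn's BinaryOutputBuffer/BinaryInputBuffer (which are not reproduced here):

    #include <cassert>
    #include <sstream>
    #include <string>
    #include <utility>

    // Toy stand-ins for the cldnn buffers, only to demonstrate save/load symmetry.
    struct toy_output_buffer {
        std::ostringstream os;
        toy_output_buffer& operator<<(bool v) { os << (v ? '1' : '0'); return *this; }
    };
    struct toy_input_buffer {
        std::istringstream is;
        explicit toy_input_buffer(std::string s) : is(std::move(s)) {}
        toy_input_buffer& operator>>(bool& v) { char c = '0'; is >> c; v = (c == '1'); return *this; }
    };

    int main() {
        bool has_rotated_blocks = true;   // field renamed from has_rotation_coefficients

        toy_output_buffer ob;
        ob << has_rotated_blocks;         // mirrors ob << has_rotated_blocks in save()

        bool restored = false;
        toy_input_buffer ib(ob.os.str());
        ib >> restored;                   // mirrors ib >> has_rotated_blocks in load()
        assert(restored == has_rotated_blocks);
        return 0;
    }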
5 changes: 3 additions & 2 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -203,9 +203,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             // dependency
             args.inputs.push_back(instance.subsequence_begins_memory_ptr());
         }
-        if (desc->has_rotation_coefficients) {
-            args.inputs.push_back(instance.rotation_coefficients_memory_ptr());
+        if (desc->has_rotated_blocks) {
+            args.inputs.push_back(instance.rotated_block_indices_memory_ptr());
             args.inputs.push_back(instance.rotation_deltas_memory_ptr());
+            args.inputs.push_back(instance.rotation_trig_lut_memory_ptr());
         }
     } else if (kernel_idx == 4) {
         // Output scores calculation kernel
@@ -62,13 +62,6 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<paged_attention> {
     memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
     memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }
 
-    memory::ptr rotation_coefficients_memory_ptr() const {
-        return input_memory_ptr(13);
-    }
-    memory::ptr rotated_block_indices_memory_ptr() const {
-        return input_memory_ptr(14);
-    }
-
     std::shared_ptr<network> prefill_network;
 
 protected:
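
Taken together with the OCL impl change above, the optional cache-rotation inputs now occupy three fixed slots, and the push order must match the kernel's positional parameters. A sketch of the assumed mapping (the indices come from this diff; the enum itself is illustrative and does not exist in the source):

    #include <cstddef>

    // Illustrative only: fixed input slots for the cache-rotation inputs after this
    // commit. Indices are taken from the accessors above and from
    // rotated_block_indices_idx later in this commit.
    enum rotation_input_slot : std::size_t {
        ROTATED_BLOCK_INDICES = 13,  // kernel arg INPUT8_TYPE*  (rotated_block_indices)
        ROTATION_DELTAS       = 14,  // kernel arg INPUT9_TYPE*  (rotation_deltas)
        ROTATION_TRIG_LUT     = 15,  // kernel arg INPUT10_TYPE* (rotation_trig_lut)
    };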
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -98,7 +98,7 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
     paged_attention_info.add("kv_heads_num", desc->kv_heads_num);
     paged_attention_info.add("scale", desc->scale_val.value_or(1.0f));
     paged_attention_info.add("has_alibi", desc->has_alibi);
-    paged_attention_info.add("has_rotation_coefficients", desc->has_rotation_coefficients);
+    paged_attention_info.add("has_rotated_blocks", desc->has_rotated_blocks);
     node_info->add("paged_attention primitive info", paged_attention_info);
     node_info->dump(primitive_description);
@@ -43,9 +43,10 @@ KERNEL(pa_sdpa_opt)(
 #if HAS_ALIBI
     const __global ALIBI_INPUT_TYPE* alibi_slopes,
 #endif
-#if HAS_ROTATION_COEFFICIENTS
-    const __global INPUT8_TYPE* rotation_coefficients,
-    const __global INPUT9_TYPE* rotated_block_indices,
+#if HAS_ROTATED_BLOCKS
+    const __global INPUT8_TYPE* rotated_block_indices,
+    const __global INPUT9_TYPE* rotation_deltas,
+    const __global INPUT10_TYPE* rotation_trig_lut,
 #endif
     __global OUTPUT_TYPE* output,
 #if PAGED_ATTENTION_SCORES_OUTPUT
@@ -67,8 +68,9 @@ KERNEL(pa_sdpa_opt)(
     // subsequence_begins: [sequences_num + 1]
     // block_indices: [used_blocks_num]
     // block_indices_begins: [sequences_num + 1]
-    // rotation_coefficients: [num_rotated_blocks * PAGED_ATTENTION_BLOCK_SIZE]
-    // rotated_block_indices: [num_rotated_blocks ]
+    // rotated_block_indices: [num_rotated_blocks]
+    // rotation_deltas [num_rotated_blocks, 1 || PAGED_ATTENTION_BLOCK_SIZE ]
+    // rotation_trig_lut [MAX_CONTEXT_LEN, HEAD_SIZE]
     //
     // Output shapes:
     // output: [sequences_num, HEADS_NUM * HEAD_SIZE]
@@ -154,7 +156,7 @@ KERNEL(pa_sdpa_opt)(
     }
 #endif
 
-#ifdef HAS_ROTATION_COEFFICIENTS
+#ifdef HAS_ROTATED_BLOCKS
     // TODO (vshampor): add cache block rotation at this spot
 #endif
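
The TODO in the hunk above marks where rotated cache blocks would actually be re-rotated. Purely as a sketch of what that step could look like, in host-side C++ rather than OpenCL, and under assumptions this commit does not fix (per-block deltas index rows of the trig LUT; each LUT row stores HEAD_SIZE/2 cosines followed by HEAD_SIZE/2 sines; rotation pairs element i with element i + HEAD_SIZE/2):

    #include <cstddef>
    #include <vector>

    // Sketch: rotate the cached keys of one block, given the inputs added in this
    // commit. Layout assumptions (NOT established by the diff):
    //   key_block: [block_size][head_size] values for one cache block
    //   trig_lut:  [max_context_len][head_size], row = head_size/2 cos then head_size/2 sin
    //   delta:     one LUT row index per block (the "1" variant of rotation_deltas)
    void rotate_block(std::vector<float>& key_block,
                      const std::vector<float>& trig_lut,
                      std::size_t delta, std::size_t block_size, std::size_t head_size) {
        const float* row = trig_lut.data() + delta * head_size;
        const std::size_t half = head_size / 2;
        for (std::size_t t = 0; t < block_size; ++t) {
            float* k = key_block.data() + t * head_size;
            for (std::size_t i = 0; i < half; ++i) {
                const float c = row[i], s = row[half + i];
                const float x = k[i], y = k[half + i];
                k[i]        = x * c - y * s;  // standard 2D rotation of each (x, y) pair
                k[half + i] = x * s + y * c;
            }
        }
    }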
@@ -237,8 +237,8 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_SCORES_OUTPUT", 1));
     }
 
-    if (params.conf.has_rotation_coefficients_input)
-        jit.AddConstant(MakeJitConstant("HAS_ROTATION_COEFFICIENTS", 1));
+    if (params.conf.has_rotated_blocks)
+        jit.AddConstant(MakeJitConstant("HAS_ROTATED_BLOCKS", 1));
 
     if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS)
         jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
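
For readers outside kernel_selector: each MakeJitConstant entry is assumed to become a preprocessor define prepended to the OpenCL source, which is what the #if HAS_ROTATED_BLOCKS guard in pa_sdpa_opt reacts to. A toy model of that mechanism (not the real JitConstants implementation):

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Toy model: JIT constants become #defines prepended to the kernel source, so
    // "HAS_ROTATED_BLOCKS" gates the extra kernel parameters at compile time.
    int main() {
        std::vector<std::pair<std::string, int>> jit = {{"HAS_ROTATED_BLOCKS", 1}};
        std::string header;
        for (const auto& [name, value] : jit)
            header += "#define " + name + " " + std::to_string(value) + "\n";
        std::cout << header;  // prints: #define HAS_ROTATED_BLOCKS 1
        return 0;
    }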
@@ -100,7 +100,7 @@ struct sdpa_configuration {
     int64_t paged_attention_max_len = 0;
     bool has_const_scale_val = false;
     float scale_val = 0.f;
-    bool has_rotation_coefficients_input = false;
+    bool has_rotated_blocks = false;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
10 changes: 5 additions & 5 deletions src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
@@ -48,7 +48,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
 
     const size_t scale_idx = 9;
     const size_t alibi_idx = 11;
-    const size_t rotation_coefficients_idx = 13;
+    const size_t rotated_block_indices_idx = 13;
 
     std::shared_ptr<ov::op::v0::Constant> scale_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(scale_idx));
     if (scale_const) {
@@ -64,10 +64,10 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
 
     prim.num_outputs = 1;
 
-    std::shared_ptr<ov::op::v0::Constant> rotation_coefficients_const =
-        std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(rotation_coefficients_idx));
-    OPENVINO_ASSERT(rotation_coefficients_const != nullptr);
-    prim.has_rotation_coefficients = ov::shape_size(rotation_coefficients_const->get_output_shape(0)) > 0;
+    std::shared_ptr<ov::op::v0::Constant> rotated_block_indices_const =
+        std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(rotated_block_indices_idx));
+    OPENVINO_ASSERT(rotated_block_indices_const != nullptr);
+    prim.has_rotated_blocks = ov::shape_size(rotated_block_indices_const->get_output_shape(0)) > 0;
 
     if (op->get_output_size() > 1) {
         const auto scores_output_idx = 1;
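
The presence check above relies on ov::shape_size returning the product of a shape's dimensions, so an empty placeholder constant (any zero dimension) yields 0 and leaves has_rotated_blocks false. A standalone illustration with a local stand-in for ov::shape_size (the helper below mimics the check; it is not the plugin code):

    #include <cstddef>
    #include <vector>

    // Local stand-in for ov::shape_size: product of dimensions, 0 if any dimension is 0.
    static std::size_t shape_size(const std::vector<std::size_t>& shape) {
        std::size_t n = 1;
        for (std::size_t d : shape) n *= d;
        return n;
    }

    int main() {
        bool empty_placeholder = shape_size({0}) > 0;  // no rotated blocks -> false
        bool four_blocks       = shape_size({4}) > 0;  // non-empty constant -> true
        return (!empty_placeholder && four_blocks) ? 0 : 1;
    }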
