Skip to content

Commit

Permalink
Fix build
Browse files Browse the repository at this point in the history
Signed-off-by: Evgeniia Nugmanova <[email protected]>
  • Loading branch information
jane-intel committed Dec 24, 2024
1 parent 658c6df commit 7c183db
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,25 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
}
}

for (const std::string& param_name : {"beam_idx", "attention_mask"}) {
for (const auto& param_name : {"beam_idx", "attention_mask"}) {
if (auto param = get_parameter(model, param_name)) {
if (param_name == "attention_mask" && param->output(0).get_target_inputs().size() == 1)
param->output(0).get_target_inputs().begin()->replace_source_output(input_ids_node->output(0));
auto target_inputs = param->output(0).get_target_inputs();
if (!strcmp(param_name, "attention_mask") && target_inputs.size() == 1 &&
ov::is_type<op::util::ShapeOfBase>(target_inputs.begin()->get_node())) {
target_inputs.begin()->replace_source_output(input_ids_node->output(0));
target_inputs = param->output(0).get_target_inputs();
}
model->remove_parameter(param);

if (param->output(0).get_target_inputs().size() == 0) {
if (!target_inputs.empty()) {
std::stringstream consumers;
consumers << std::endl;
for (auto& input : param->output(0).get_target_inputs()) {
for (auto& input : target_inputs) {
consumers << *input.get_node() << std::endl;
}
OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
OPENVINO_ASSERT(target_inputs.empty(),
"PagedAttention transformation failed: couldn't remove ",
param->output(0).get_target_inputs().size(),
target_inputs.size(),
" inputs of ",
param_name,
" input: ",
Expand Down

0 comments on commit 7c183db

Please sign in to comment.