Skip to content

Commit

Permalink
Fix PagedAttention PrevSequenceLengthPattern pattern (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#24514)

Fix PagedAttention PrevSequenceLengthPattern pattern

Fix pattern of PrevSequenceLengthPattern that makes it match twice than
needed resulting in the incorrect inference.
Provide the correct version of Gather (v8::Gather) and amend
PrevSequenceLengthPattern pattern for the correct matching.

Signed-off-by: Andrii Staikov <[email protected]>

### Details:
 - *item1*
 - *...*

### Tickets:
 - *ticket-id*

Signed-off-by: Andrii Staikov <[email protected]>
Co-authored-by: Ivan Tikhonov <[email protected]>
  • Loading branch information
CuriousPanCake and itikhono authored May 15, 2024
1 parent a2cd7be commit 6950460
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_gather});
auto seq = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
// TODO: Check that seq has axis that really takes sequence len but not any other dimension -- use symbolics or
Expand Down
2 changes: 1 addition & 1 deletion src/core/src/pass/sdpa_to_paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
};
auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window

auto cur_seq_len = std::make_shared<v1::Gather>(std::make_shared<v3::ShapeOf>(model->input("input_ids")),
auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(model->input("input_ids")),
v0::Constant::create(element::i64, Shape{}, {1}),
v0::Constant::create(element::i64, Shape{}, {0}));
auto prev_max_seq_len = std::make_shared<v1::Subtract>(max_context_len, cur_seq_len);
Expand Down

0 comments on commit 6950460

Please sign in to comment.