diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp
index 8b7caddd61e491..10091bdad5b9c7 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp
@@ -19,7 +19,7 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
     auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
     auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
     auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_gather});
-    auto seq = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
+    auto seq = pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});
 
     ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         // TODO: Check that seq has axis that really takes sequence len but not any other dimension -- use symbolics or
diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp
index c2c899dd515149..6f1d182141b162 100644
--- a/src/core/src/pass/sdpa_to_paged_attention.cpp
+++ b/src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -38,7 +38,7 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Model>& model) {
-    auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v0::ShapeOf>(model->input("input_ids")),
+    auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(model->input("input_ids")),
                                                     v0::Constant::create(element::i64, Shape{}, {1}),
                                                     v0::Constant::create(element::i64, Shape{}, {0}));
     auto prev_max_seq_len = std::make_shared<v1::Subtract>(max_context_len, cur_seq_len);
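For the first hunk: PrevSequenceLengthPattern is meant to match the subgraph that reads the previous sequence length off the KV-cache state, i.e. a Gather over the ShapeOf of the beam-reordered cache, not a Gather over the cache tensor itself. Before the fix, `seq` anchored on `kv_past`, so the matcher described a second read of the state rather than the dimension read it needs to rewrite. The sketch below restates the corrected chain as a standalone helper; it is illustrative only, and the function name make_prev_seq_len_pattern is invented for the example.

    // Minimal sketch of the corrected pattern chain, mirroring the operator
    // versions used in the patched file (v6::ReadValue, v8::Gather, v3::ShapeOf).
    #include <memory>

    #include "openvino/op/gather.hpp"
    #include "openvino/op/read_value.hpp"
    #include "openvino/op/shape_of.hpp"
    #include "openvino/pass/pattern/op/wrap_type.hpp"

    using namespace ov::op;
    namespace pattern = ov::pass::pattern;

    // Matches: ReadValue (kv-cache state) -> Gather (beam reordering)
    //          -> ShapeOf -> Gather (selects the sequence-length dimension).
    std::shared_ptr<ov::Node> make_prev_seq_len_pattern() {
        auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
        auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
        auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_gather});
        // The fix: the final Gather must consume the shape, not the raw state.
        return pattern::wrap_type<v8::Gather>({kv_shape, pattern::any_input(), pattern::any_input()});
    }

The difference is a single pattern input, but it decides whether the matcher anchors on the cache tensor or on its shape; only the latter is the sequence-length value that the callback is supposed to replace.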
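For the second hunk: the pass derives the current sequence length as dimension 1 of input_ids, whose shape is [batch_size, seq_len], and subtracts it from the max_context_len input to obtain the previous maximum sequence length. Below is a minimal sketch of that arithmetic, assuming an ov::Model exposing an "input_ids" input and an i64 max_context_len node; the helper name prev_seq_len_from is invented for the example.

    #include <memory>

    #include "openvino/core/model.hpp"
    #include "openvino/op/constant.hpp"
    #include "openvino/op/gather.hpp"
    #include "openvino/op/shape_of.hpp"
    #include "openvino/op/subtract.hpp"

    using namespace ov;
    using namespace ov::op;

    std::shared_ptr<Node> prev_seq_len_from(const std::shared_ptr<Model>& model,
                                            const Output<Node>& max_context_len) {
        // ShapeOf(input_ids) -> [batch_size, seq_len]; Gather(index=1, axis=0)
        // selects seq_len as a scalar because both constants have Shape{}.
        auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(model->input("input_ids")),
                                                        v0::Constant::create(element::i64, Shape{}, {1}),
                                                        v0::Constant::create(element::i64, Shape{}, {0}));
        // prev_max_seq_len = max_context_len - cur_seq_len; assumes both operands
        // are i64, since Subtract requires matching element types.
        return std::make_shared<v1::Subtract>(max_context_len, cur_seq_len);
    }

The scalar Shape{} constants matter here: gathering with a scalar index along axis 0 drops the axis, so cur_seq_len comes out as a scalar that can be subtracted from max_context_len directly.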