Skip to content

Commit

Permalink
SDPA decomposition tests are extended
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 24, 2025
1 parent b35bc5d commit 1b63229
Showing 1 changed file with 59 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ using namespace ov;
using namespace testing;

const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(
const std::shared_ptr<ov::Node> query,
const std::shared_ptr<ov::Node> key,
const std::shared_ptr<ov::Node> value,
const std::shared_ptr<ov::Node> attention_mask,
const std::shared_ptr<ov::Node> scale,
const bool casual);
std::shared_ptr<ov::Node> query,
std::shared_ptr<ov::Node> key,
std::shared_ptr<ov::Node> value,
std::shared_ptr<ov::Node> attention_mask,
std::shared_ptr<ov::Node> scale,
bool casual);

TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBasic) {
const PartialShape query_shape{1, 32, 32};
Expand Down Expand Up @@ -129,6 +129,34 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBroadca
}
}

TEST_F(TransformationTestsF, ScaledDotProductAttentionCasualPartiallyDynamic) {
    // Exercises the SDPA decomposition on 4D inputs where the leading (batch/seq)
    // dimensions are dynamic and only the trailing dimensions are static.
    const auto is_causal = true;

    const auto query = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, 24, 64});
    const auto key = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, 24, 64});
    const auto value = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, 64});
    const auto attention_mask = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    const ParameterVector inputs{query, key, value, attention_mask};

    {
        // Model under test: a single SDPA node which the registered pass decomposes.
        const auto sdpa =
            std::make_shared<ov::op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, is_causal);
        model = std::make_shared<ov::Model>(NodeVector{sdpa}, inputs);
        manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
    }

    {
        // Reference model: the manually built decomposition; nullptr means no
        // explicit scale input was provided to the original SDPA.
        const auto decomposed =
            scaled_dot_product_attention_decomposition(query, key, value, attention_mask, nullptr, is_causal);
        model_ref = std::make_shared<ov::Model>(NodeVector{decomposed}, inputs);
    }
}

TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
const PartialShape query_shape{-1, -1, -1};
const PartialShape key_shape{-1, -1, -1};
Expand Down Expand Up @@ -160,12 +188,12 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
}

const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(
const std::shared_ptr<ov::Node> query,
const std::shared_ptr<ov::Node> key,
const std::shared_ptr<ov::Node> value,
const std::shared_ptr<ov::Node> attention_mask,
const std::shared_ptr<ov::Node> scale,
const bool casual) {
std::shared_ptr<ov::Node> query,
std::shared_ptr<ov::Node> key,
std::shared_ptr<ov::Node> value,
std::shared_ptr<ov::Node> attention_mask,
std::shared_ptr<ov::Node> scale,
bool casual) {
const auto q_shape = std::make_shared<ov::op::v3::ShapeOf>(query, element::i32);
const auto k_shape = std::make_shared<ov::op::v3::ShapeOf>(key, element::i32);
const auto minus_one = ov::op::v0::Constant::create(element::i32, Shape{}, {-1});
Expand All @@ -175,6 +203,23 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(
const auto one_f = std::make_shared<ov::op::v1::ConvertLike>(one_i, query);
const auto zero_f = std::make_shared<ov::op::v1::ConvertLike>(zero_i, query);

auto extract_dim = [&zero_i](const std::shared_ptr<ov::op::v3::ShapeOf>& shape_of,
const int64_t idx) -> std::shared_ptr<ov::Node> {
const auto& shape = shape_of->get_input_partial_shape(0);
const auto& dim = shape[idx];
if (dim.is_static()) {
return ov::op::v0::Constant::create(element::i32, Shape{}, {dim.get_length()});
}
const auto dim_to_extract_const = ov::op::v0::Constant::create(element::i32, Shape{}, {idx});
return std::make_shared<ov::op::v8::Gather>(shape_of, dim_to_extract_const, zero_i);
};

if (scale == nullptr) {
scale = extract_dim(q_shape, -1);
scale = std::make_shared<ov::op::v1::ConvertLike>(scale, query);
auto sqrt_scale = std::make_shared<ov::op::v0::Sqrt>(scale);
scale = std::make_shared<ov::op::v1::Divide>(one_f, sqrt_scale);
}
const auto q_scaled = std::make_shared<ov::op::v1::Multiply>(query, scale);
auto k_rank = std::make_shared<ov::op::v3::ShapeOf>(k_shape, element::i32)->output(0);
const auto k_last_dim = std::make_shared<ov::op::v1::Add>(k_rank, minus_one);
Expand Down Expand Up @@ -204,8 +249,8 @@ const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(
atten_mask = mask;
}
} else {
const auto target_s_len = std::make_shared<ov::op::v8::Gather>(q_shape, minus_two, zero_i);
const auto source_s_len = std::make_shared<ov::op::v8::Gather>(k_shape, minus_two, zero_i);
const auto target_s_len = extract_dim(q_shape, -2);
const auto source_s_len = extract_dim(k_shape, -2);
const auto ssl = std::make_shared<ov::op::v0::Unsqueeze>(source_s_len, zero_i);
const auto tsl = std::make_shared<ov::op::v0::Unsqueeze>(target_s_len, zero_i);
const auto mask_shape = std::make_shared<ov::op::v0::Concat>(OutputVector{tsl, ssl}, 0);
Expand Down

0 comments on commit 1b63229

Please sign in to comment.