diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 6018302b507cd7..bd213d0dcb9bfe 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -686,6 +686,54 @@ Tensor _safe_softmax(
   // reuse storage for out
   return at::where_out(out, masked_rows, zero, out);
 }
+
+Tensor expand_boolean_mask_for_hip(
+    const std::optional<Tensor>& attn_mask,
+    const Tensor& query) {
+
+  if (!attn_mask.has_value() || !attn_mask->defined()) {
+    return {};
+  }
+
+  const auto& mask = attn_mask.value();
+
+  // Only expand boolean masks on HIP backend
+  if (!query.device().is_hip() || mask.dtype() != at::kBool) {
+    return mask;
+  }
+
+  const auto query_sizes = query.sizes();
+  const auto batch_size = query_sizes[0];
+  const auto num_heads = query_sizes[1];
+  const auto seq_len = query_sizes[2];
+
+  // Handle 2D boolean mask: [batch, seq_len] -> [batch, heads, seq_len, seq_len]
+  if (mask.dim() == 2) {
+    TORCH_CHECK(mask.size(0) == batch_size,
+        "Attention mask batch size (", mask.size(0),
+        ") doesn't match query batch size (", batch_size, ")");
+    TORCH_CHECK(mask.size(1) == seq_len,
+        "Attention mask sequence length (", mask.size(1),
+        ") doesn't match query sequence length (", seq_len, ")");
+
+    return mask.unsqueeze(1).unsqueeze(1).expand({batch_size, num_heads, seq_len, seq_len});
+  }
+
+  // Handle 3D boolean mask: [batch, seq_len, seq_len] -> [batch, heads, seq_len, seq_len]
+  if (mask.dim() == 3) {
+    TORCH_CHECK(mask.size(0) == batch_size,
+        "Attention mask batch size (", mask.size(0),
+        ") doesn't match query batch size (", batch_size, ")");
+    TORCH_CHECK(mask.size(1) == seq_len && mask.size(2) == seq_len,
+        "Attention mask sequence dimensions ([", mask.size(1), ", ", mask.size(2),
+        "]) don't match query sequence length (", seq_len, ")");
+
+    return mask.unsqueeze(1).expand({batch_size, num_heads, seq_len, seq_len});
+  }
+
+  // 4D masks and other formats are already in correct format
+  return mask;
+}
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
 // greater than 0.0 is specified.
@@ -733,8 +781,15 @@ Tensor scaled_dot_product_attention(
   }
   const auto query_device_type = query_.device().type();
   const auto backend = static_cast<SDPBackend>(choice_int);
+
+  auto hip_expanded_mask = expand_boolean_mask_for_hip(attn_mask_, query_);
+  std::optional<Tensor> hip_expanded_mask_opt;
+  if (hip_expanded_mask.defined()) {
+    hip_expanded_mask_opt = hip_expanded_mask;
+  }
   const auto convert_attn_func = backend != SDPBackend::cudnn_attention ? convert_boolean_attn_mask : convert_boolean_attn_mask_cudnn;
-  auto attn_mask = convert_attn_func(attn_mask_, query_.dtype());
+  auto attn_mask = convert_attn_func(hip_expanded_mask_opt, query_.dtype());
+
   switch (backend) {
     case SDPBackend::cudnn_attention: {
       bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
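
For reference, the shape transform applied by the 2D-mask branch of expand_boolean_mask_for_hip can be reproduced with public ATen calls alone. The standalone sketch below is not part of the patch: the batch/head/sequence sizes are hypothetical, it runs on CPU for illustration, and it only demonstrates the unsqueeze/expand broadcast, whereas the helper itself fires solely for boolean masks paired with HIP-backed queries.

// Minimal sketch (illustrative, not part of the patch): mirrors the 2D-mask
// expansion performed by expand_boolean_mask_for_hip.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Hypothetical sizes; in the patch they are read from query.sizes()[0..2].
  const int64_t batch = 2, heads = 4, seq = 8;
  at::Tensor mask2d = at::ones({batch, seq}, at::kBool);

  // [batch, seq] -> [batch, 1, 1, seq] -> broadcast view [batch, heads, seq, seq],
  // the same unsqueeze(1).unsqueeze(1).expand(...) chain the helper uses.
  at::Tensor expanded =
      mask2d.unsqueeze(1).unsqueeze(1).expand({batch, heads, seq, seq});
  std::cout << expanded.sizes() << '\n';  // prints [2, 4, 8, 8]
  return 0;
}

Because expand only creates a broadcast view, the expanded mask shares storage with the original 2D tensor; no copy is made before it is handed to convert_boolean_attn_mask.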