2D/3D Mask Attention issue #2257

Open · wants to merge 1 commit into main
57 changes: 56 additions & 1 deletion aten/src/ATen/native/transformers/attention.cpp
@@ -686,6 +686,54 @@ Tensor _safe_softmax(
  // reuse storage for out
  return at::where_out(out, masked_rows, zero, out);
}

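// Expands a 2D ([batch, seq_len]) or 3D ([batch, seq_len, seq_len]) boolean
// attention mask to the 4D [batch, heads, seq_len, seq_len] layout.
// Non-boolean masks, non-HIP devices, and masks already in another layout
// are returned unchanged; a missing mask yields an undefined Tensor.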
Tensor expand_boolean_mask_for_hip(
    const std::optional<Tensor>& attn_mask,
    const Tensor& query) {
  if (!attn_mask.has_value() || !attn_mask->defined()) {
    return {};
  }

  const auto& mask = attn_mask.value();

  // Only expand boolean masks on the HIP backend
  if (!query.device().is_hip() || mask.dtype() != at::kBool) {
    return mask;
  }

  const auto query_sizes = query.sizes();
  const auto batch_size = query_sizes[0];
  const auto num_heads = query_sizes[1];
  const auto seq_len = query_sizes[2];

  // Handle 2D boolean mask: [batch, seq_len] -> [batch, heads, seq_len, seq_len]
  if (mask.dim() == 2) {
    TORCH_CHECK(mask.size(0) == batch_size,
        "Attention mask batch size (", mask.size(0),
        ") doesn't match query batch size (", batch_size, ")");
    TORCH_CHECK(mask.size(1) == seq_len,
        "Attention mask sequence length (", mask.size(1),
        ") doesn't match query sequence length (", seq_len, ")");

    return mask.unsqueeze(1).unsqueeze(1).expand({batch_size, num_heads, seq_len, seq_len});
  }

  // Handle 3D boolean mask: [batch, seq_len, seq_len] -> [batch, heads, seq_len, seq_len]
  if (mask.dim() == 3) {
    TORCH_CHECK(mask.size(0) == batch_size,
        "Attention mask batch size (", mask.size(0),
        ") doesn't match query batch size (", batch_size, ")");
    TORCH_CHECK(mask.size(1) == seq_len && mask.size(2) == seq_len,
        "Attention mask sequence dimensions ([", mask.size(1), ", ", mask.size(2),
        "]) don't match query sequence length (", seq_len, ")");

    return mask.unsqueeze(1).expand({batch_size, num_heads, seq_len, seq_len});
  }

  // 4D masks and other shapes are already in the expected format
  return mask;
}
// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.
@@ -733,8 +781,15 @@ Tensor scaled_dot_product_attention(
  }
  const auto query_device_type = query_.device().type();
  const auto backend = static_cast<SDPBackend>(choice_int);

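  // On HIP, pre-expand 2D/3D boolean masks to 4D so the generic
  // boolean-mask conversion below sees a uniform shape.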
  auto hip_expanded_mask = expand_boolean_mask_for_hip(attn_mask_, query_);
  std::optional<Tensor> hip_expanded_mask_opt;
  if (hip_expanded_mask.defined()) {
    hip_expanded_mask_opt = hip_expanded_mask;
  }
  const auto convert_attn_func = backend != SDPBackend::cudnn_attention ? convert_boolean_attn_mask : convert_boolean_attn_mask_cudnn;
-  auto attn_mask = convert_attn_func(attn_mask_, query_.dtype());
+  auto attn_mask = convert_attn_func(hip_expanded_mask_opt, query_.dtype());

  switch (backend) {
    case SDPBackend::cudnn_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
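
For reviewers, here is a minimal Python sketch of what the new helper does. It is not code from this PR: the shapes, tensor names, and the manual pre-expansion shown as a workaround for non-HIP devices are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes, not taken from the PR.
    batch, heads, seq_len, head_dim = 2, 4, 8, 16
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    v = torch.randn(batch, heads, seq_len, head_dim)

    # 2D boolean key-padding mask: True = attend, False = masked out.
    mask_2d = torch.ones(batch, seq_len, dtype=torch.bool)
    mask_2d[:, -2:] = False

    # Python equivalent of the 2D branch in expand_boolean_mask_for_hip:
    # [batch, seq_len] -> [batch, heads, seq_len, seq_len]
    expanded_2d = mask_2d.unsqueeze(1).unsqueeze(1).expand(
        batch, heads, seq_len, seq_len)

    # Python equivalent of the 3D branch:
    # [batch, seq_len, seq_len] -> [batch, heads, seq_len, seq_len]
    mask_3d = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool)).expand(batch, -1, -1)
    expanded_3d = mask_3d.unsqueeze(1).expand(batch, heads, seq_len, seq_len)

    # With this change, a HIP device accepts mask_2d / mask_3d directly;
    # the pre-expanded 4D masks work on any backend:
    out_2d = F.scaled_dot_product_attention(q, k, v, attn_mask=expanded_2d)
    out_3d = F.scaled_dot_product_attention(q, k, v, attn_mask=expanded_3d)

Doing the expansion before convert_attn_func runs means the downstream conversion of boolean masks to additive float masks always sees a uniform 4D shape.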