Re-applying G42 bias triton fix on 0.4.3 (#41)
* Using rocm_flash_attention that supports a bias computed from ALiBi slopes; using the attn_fwd triton kernel from ROCm/triton main_perf, which does not cause the triton compiler to hang (see the ALiBi sketch below)

* Uninitialized variable fix
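For context, ALiBi (Attention with Linear Biases) adds a fixed, head-specific linear penalty to attention scores in place of positional embeddings; this commit wires such a bias, built from per-head slopes, into the ROCm attention paths. A minimal sketch of the idea, assuming the standard geometric slope formula for a power-of-two head count (the slope computation itself is not part of this diff):

import torch

def alibi_slopes_example(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes for a power-of-two head count:
    # 2^(-8/n), 2^(-16/n), ..., 2^(-8). For illustration only; vLLM
    # computes the slopes elsewhere and passes them into the backend.
    ratio = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(num_heads)])

slopes = alibi_slopes_example(num_heads=8)      # shape: (8,)
seq_len = 4
pos = torch.arange(seq_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]               # relative positions, (4, 4)
bias = rel[None, :, :] * slopes[:, None, None]  # per-head bias, (8, 4, 4)
# The triton kernel consumes this bias directly; the naive path additionally
# masks the upper triangle with -inf to keep attention causal.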
gshtras authored Jun 6, 2024
1 parent e4af60b commit a822875
Showing 2 changed files with 251 additions and 353 deletions.
87 changes: 79 additions & 8 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -165,6 +165,62 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         )
         return self._cached_decode_metadata

+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    dtype: torch.dtype,
+    seq_lens: List[int],
+) -> List[torch.Tensor]:
+    attn_biases = []
+    for seq_len in seq_lens:
+        bias = torch.arange(seq_len, dtype=dtype)
+        # NOTE(zhuohan): HF uses
+        #     `bias = bias[None, :].repeat(seq_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+
+        num_heads = alibi_slopes.shape[0]
+        bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device)
+        bias.mul_(alibi_slopes[:, None, None])
+        inf_mask = torch.empty(
+            (1, seq_len, seq_len),
+            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device)
+        attn_biases.append((bias + inf_mask).to(dtype))
+
+    return attn_biases
+
+
+def _make_alibi_bias_v2(
+    alibi_slopes: torch.Tensor,
+    dtype: torch.dtype,
+    seq_lens: List[int],
+    make_attn_mask: bool = True
+) -> List[torch.Tensor]:
+    attn_biases = []
+    for seq_len in seq_lens:
+        bias = torch.arange(seq_len, dtype=dtype)
+        # NOTE(zhuohan): HF uses
+        #     `bias = bias[None, :].repeat(seq_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+
+        num_heads = alibi_slopes.shape[0]
+        bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device)
+        bias.mul_(alibi_slopes[:, None, None])
+        if make_attn_mask:
+            inf_mask = torch.empty(
+                (1, seq_len, seq_len),
+                dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device)
+            attn_biases.append((bias + inf_mask).to(dtype))
+        else:
+            attn_biases.append(bias.to(dtype))
+
+    return attn_biases
+
+

 class ROCmFlashAttentionImpl(AttentionImpl):
     """
@@ -324,7 +380,12 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
+                att_masks = None
                 if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        att_masks = _make_alibi_bias_v2(
+                            self.alibi_slopes, query.dtype,
+                            attn_metadata.seq_lens, make_attn_mask=False) # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,
@@ -336,8 +397,13 @@ def forward(
                         prefill_meta.max_prefill_seq_len,
                         True,
                         self.scale,
+                        att_masks[0][None] if att_masks is not None else None,
                     )
                 elif self.use_naive_attn:
+                    if self.alibi_slopes is not None:
+                        att_masks = _make_alibi_bias_v2(
+                            self.alibi_slopes, query.dtype,
+                            attn_metadata.seq_lens, make_attn_mask=True) # type: ignore
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
@@ -348,6 +414,7 @@ def forward(
                         value,
                         prefill_meta.seq_lens,
                         self.scale,
+                        att_masks
                     )
                 else:
                     out = self.attn_func(
@@ -408,16 +475,18 @@ def _naive_attention(
     value: torch.Tensor,
     seq_lens: List[int],
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]],
 ) -> torch.Tensor:
     output = torch.empty_like(query)
     start = 0
-    for _, seq_len in enumerate(seq_lens):
+    for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
         out = _naive_masked_attention(
             query[start:end],
             key[start:end],
             value[start:end],
             scale,
+            attn_masks[i],
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output[start:end].copy_(out)
@@ -431,16 +500,18 @@ def _naive_masked_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    attn_mask: Optional[torch.Tensor],
 ) -> torch.Tensor:
     seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
+    if attn_mask is None:
+        attn_mask = torch.triu(torch.ones(seq_len,
+                                          seq_len,
+                                          dtype=query.dtype,
+                                          device=query.device),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(query.dtype).min
     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     attn_weights = attn_weights + attn_mask.float()
     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
     out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
+    return out
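Putting the naive path together, a self-contained sketch under assumed shapes (`_make_alibi_bias_v2`, `_naive_attention`, and `_naive_masked_attention` are the functions from the diff above; the packed query/key/value layout mirrors the prefill path):

import torch

num_heads, head_dim = 4, 64
seq_lens = [3, 5]                       # two prompts packed along the token axis
total_tokens = sum(seq_lens)

query = torch.randn(total_tokens, num_heads, head_dim)
key = torch.randn(total_tokens, num_heads, head_dim)
value = torch.randn(total_tokens, num_heads, head_dim)
alibi_slopes = torch.rand(num_heads)    # placeholder slopes for illustration

# make_attn_mask=True folds the causal -inf mask into the per-sequence bias,
# so _naive_masked_attention skips building its own triangular mask.
att_masks = _make_alibi_bias_v2(alibi_slopes, query.dtype, seq_lens,
                                make_attn_mask=True)
out = _naive_attention(query, key, value, seq_lens,
                       scale=head_dim**-0.5, attn_masks=att_masks)
assert out.shape == query.shape         # (total_tokens, num_heads, head_dim)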
