Commit
Support FlashAttention in SPMD (#5)
* feat: add spmd flash attention op

* style: format code
lausannel authored Aug 8, 2024
1 parent 69921ba commit 630e333
Showing 2 changed files with 148 additions and 1 deletion.
torchacc/ops/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,3 +1,3 @@
 from .flash_attn import (flash_attn_varlen_qkvpacked_xla, flash_attn_varlen_xla,
-                         flash_attn_xla)
+                         flash_attn_xla, spmd_flash_attn_varlen_xla)
 from .scaled_dot_product_attention import scaled_dot_product_attention
torchacc/ops/flash_attn.py (147 changes: 147 additions & 0 deletions)
@@ -1,6 +1,7 @@
import einops
import torch
import torch_xla
import torch_xla.distributed.spmd as xs


class FlashAttnVarlenQKVPackedXla(torch.autograd.Function):
@@ -55,6 +56,112 @@ def backward(ctx, dout, *args):
        return dqkv, None, None, None, None, None, None, None, None, None


class SPMDFlashAttnVarlenXla(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                q,
                k,
                v,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                alibi_slopes,
                deterministic,
                return_softmax,
                mesh=None,
                partition_spec=None):
        if softmax_scale is None:
            softmax_scale = q.shape[-1]**(-0.5)
        assert isinstance(window_size, tuple) and len(window_size) == 2

        ctx.partition_spec = partition_spec
        ctx.mesh = mesh
        ctx.q_full_shape = None
        ctx.k_full_shape = None  # for GQA

        full_q = q
        full_k = k
        full_v = v
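        # Under SPMD, switch q/k/v to manual sharding so the flash-attention
        # kernel runs on each device's local shard; record the global shapes
        # so the output (and gradients) can be restored to the global sharding.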
        if partition_spec is not None:
            ctx.q_full_shape = q.shape
            ctx.k_full_shape = k.shape
            q = xs.enable_manual_sharding(
                q, partition_spec, mesh=mesh).global_tensor
            k = xs.enable_manual_sharding(
                k, partition_spec, mesh=mesh).global_tensor
            v = xs.enable_manual_sharding(
                v, partition_spec, mesh=mesh).global_tensor

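        # Call the XLA flash-attention forward op on the local data; gradients
        # come from the custom backward below, so autograd tracking is not
        # needed here.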
        with torch.no_grad():
            softmax_lse, out, rng_state = torch_xla._XLAC._flash_attention_forward(
                q, k, v, cu_seqlens_q, cu_seqlens_k, alibi_slopes, max_seqlen_q,
                max_seqlen_k, dropout_p, softmax_scale, False, causal,
                window_size[0], window_size[1], return_softmax, None)

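        # Hand the attention output back to the SPMD partitioner, restoring
        # the query's global shape.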
        if partition_spec is not None:
            out = xs.disable_manual_sharding(
                out, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor

        out = out.to(q.dtype)

        ctx.save_for_backward(full_q, full_k, full_v, out, softmax_lse,
                              cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors

        partition_spec = ctx.partition_spec
        mesh = ctx.mesh

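        # Re-enter manual sharding for the saved tensors and the incoming
        # gradient so the backward kernel also runs on the local shards.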
        if partition_spec is not None:
            q = xs.enable_manual_sharding(
                q, partition_spec, mesh=mesh).global_tensor
            k = xs.enable_manual_sharding(
                k, partition_spec, mesh=mesh).global_tensor
            v = xs.enable_manual_sharding(
                v, partition_spec, mesh=mesh).global_tensor
            dout = xs.enable_manual_sharding(
                dout, partition_spec, mesh=mesh).global_tensor
            out = xs.enable_manual_sharding(
                out, partition_spec, mesh=mesh).global_tensor

        dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
            dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
            ctx.alibi_slopes, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p,
            ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
            ctx.window_size[1], ctx.deterministic, None, rng_state)

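        # Restore the gradients to globally sharded tensors; dk/dv use the
        # key/value full shape, which can differ from the query's under
        # grouped-query attention.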
        if partition_spec is not None:
            dq = xs.disable_manual_sharding(
                dq, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
            dk = xs.disable_manual_sharding(
                dk, partition_spec, ctx.k_full_shape, mesh=mesh).global_tensor
            dv = xs.disable_manual_sharding(
                dv, partition_spec, ctx.k_full_shape, mesh=mesh).global_tensor

        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]

        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenXla(torch.autograd.Function):

    @staticmethod
@@ -185,6 +292,46 @@ def flash_attn_varlen_qkvpacked_xla(
    )


def spmd_flash_attn_varlen_xla(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        mesh=None,
        partition_spec=None,
):
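    """Variable-length FlashAttention that cooperates with PyTorch/XLA SPMD.

    Same interface as flash_attn_varlen_xla, with two extra arguments: `mesh`
    and `partition_spec` describe how q/k/v are sharded. When provided, the
    kernel runs on each device's local shard and the output is re-annotated
    with the global sharding.
    """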
    assert q.dtype in [torch.bfloat16,
                       torch.float16], 'flash attention only supports fp16/bf16'
    return SPMDFlashAttnVarlenXla.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        mesh,
        partition_spec,
    )


def flash_attn_varlen_xla(
        q,
        k,

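For reference, a minimal usage sketch of the new spmd_flash_attn_varlen_xla entry point. The mesh layout, tensor shapes, and the choice to shard the head dimension along a 'model' axis are illustrative assumptions, not part of this commit:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from torchacc.ops import spmd_flash_attn_varlen_xla

xr.use_spmd()
device = xm.xla_device()

# Hypothetical 1-D device mesh; attention heads are sharded across the
# 'model' axis, so each device computes attention for its own heads.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('model',))
partition_spec = (None, 'model', None)  # (total_tokens, num_heads, head_dim)

total_tokens, num_heads, head_dim = 1024, 16, 128
q = torch.randn(total_tokens, num_heads, head_dim,
                dtype=torch.bfloat16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
for t in (q, k, v):
    xs.mark_sharding(t, mesh, partition_spec)

# One packed sequence of length total_tokens (varlen layout).
cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device=device)

out = spmd_flash_attn_varlen_xla(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=total_tokens,
    max_seqlen_k=total_tokens,
    dropout_p=0.0,
    causal=True,
    mesh=mesh,
    partition_spec=partition_spec,
)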