From 630e333d79ccc1ad481696185a255a00bb8b1e66 Mon Sep 17 00:00:00 2001
From: Zhan Lu <51200935+lausannel@users.noreply.github.com>
Date: Thu, 8 Aug 2024 15:58:23 +0800
Subject: [PATCH] Support FlashAttention in SPMD (#5)

* feat: add spmd flash attention op

* style: format code
---
 torchacc/ops/__init__.py   |   2 +-
 torchacc/ops/flash_attn.py | 147 +++++++++++++++++++++++++++++++++++++
 2 files changed, 148 insertions(+), 1 deletion(-)

diff --git a/torchacc/ops/__init__.py b/torchacc/ops/__init__.py
index 97432cf..a026e68 100644
--- a/torchacc/ops/__init__.py
+++ b/torchacc/ops/__init__.py
@@ -1,3 +1,3 @@
 from .flash_attn import (flash_attn_varlen_qkvpacked_xla, flash_attn_varlen_xla,
-                         flash_attn_xla)
+                         flash_attn_xla, spmd_flash_attn_varlen_xla)
 from .scaled_dot_product_attention import scaled_dot_product_attention
diff --git a/torchacc/ops/flash_attn.py b/torchacc/ops/flash_attn.py
index f6f03bd..910668b 100644
--- a/torchacc/ops/flash_attn.py
+++ b/torchacc/ops/flash_attn.py
@@ -1,6 +1,7 @@
 import einops
 import torch
 import torch_xla
+import torch_xla.distributed.spmd as xs
 
 
 class FlashAttnVarlenQKVPackedXla(torch.autograd.Function):
@@ -55,6 +56,112 @@ def backward(ctx, dout, *args):
         return dqkv, None, None, None, None, None, None, None, None, None
 
 
+class SPMDFlashAttnVarlenXla(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx,
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                dropout_p,
+                softmax_scale,
+                causal,
+                window_size,
+                alibi_slopes,
+                deterministic,
+                return_softmax,
+                mesh=None,
+                partition_spec=None):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1]**(-0.5)
+        assert isinstance(window_size, tuple) and len(window_size) == 2
+
+        ctx.partition_spec = partition_spec
+        ctx.mesh = mesh
+        ctx.q_full_shape = None
+        ctx.k_full_shape = None  # for GQA
+
+        full_q = q
+        full_k = k
+        full_v = v
+        if partition_spec is not None:
+            ctx.q_full_shape = q.shape
+            ctx.k_full_shape = k.shape
+            q = xs.enable_manual_sharding(
+                q, partition_spec, mesh=mesh).global_tensor
+            k = xs.enable_manual_sharding(
+                k, partition_spec, mesh=mesh).global_tensor
+            v = xs.enable_manual_sharding(
+                v, partition_spec, mesh=mesh).global_tensor
+
+        with torch.no_grad():
+            softmax_lse, out, rng_state = torch_xla._XLAC._flash_attention_forward(
+                q, k, v, cu_seqlens_q, cu_seqlens_k, alibi_slopes, max_seqlen_q,
+                max_seqlen_k, dropout_p, softmax_scale, False, causal,
+                window_size[0], window_size[1], return_softmax, None)
+
+        if partition_spec is not None:
+            out = xs.disable_manual_sharding(
+                out, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
+
+        out = out.to(q.dtype)
+
+        ctx.save_for_backward(full_q, full_k, full_v, out, softmax_lse,
+                              cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+        return out if not return_softmax else (out, softmax_lse)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+
+        partition_spec = ctx.partition_spec
+        mesh = ctx.mesh
+
+        if partition_spec is not None:
+            q = xs.enable_manual_sharding(
+                q, partition_spec, mesh=mesh).global_tensor
+            k = xs.enable_manual_sharding(
+                k, partition_spec, mesh=mesh).global_tensor
+            v = xs.enable_manual_sharding(
+                v, partition_spec, mesh=mesh).global_tensor
+            dout = xs.enable_manual_sharding(
+                dout, partition_spec, mesh=mesh).global_tensor
+            out = xs.enable_manual_sharding(
+                out, partition_spec, mesh=mesh).global_tensor
+
+        dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
+            dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            ctx.alibi_slopes, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p,
+            ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
+            ctx.window_size[1], ctx.deterministic, None, rng_state)
+
+        if partition_spec is not None:
+            dq = xs.disable_manual_sharding(
+                dq, partition_spec, ctx.q_full_shape, mesh=mesh).global_tensor
+            dk = xs.disable_manual_sharding(
+                dk, partition_spec, ctx.k_full_shape, mesh=mesh).global_tensor
+            dv = xs.disable_manual_sharding(
+                dv, partition_spec, ctx.k_full_shape, mesh=mesh).global_tensor
+
+        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
+        dk = dk[..., :dout.shape[-1]]
+        dv = dv[..., :dout.shape[-1]]
+
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
 class FlashAttnVarlenXla(torch.autograd.Function):
 
     @staticmethod
@@ -185,6 +292,46 @@ def flash_attn_varlen_qkvpacked_xla(
     )
 
 
+def spmd_flash_attn_varlen_xla(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    mesh=None,
+    partition_spec=None,
+):
+    assert q.dtype in [torch.bfloat16,
+                       torch.float16], 'flash attention only supports fp16/bf16'
+    return SPMDFlashAttnVarlenXla.apply(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_attn_probs,
+        mesh,
+        partition_spec,
+    )
+
+
 def flash_attn_varlen_xla(
     q,
     k,
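
Usage note (illustrative, not part of the diff): the sketch below shows how the new spmd_flash_attn_varlen_xla entry point is intended to be called. The SPMD mesh construction, the 'model' axis name, the tensor shapes, and the (None, 'model', None) partition_spec are assumptions chosen only to demonstrate how the mesh/partition_spec arguments flow into the manual-sharding region; only the function signature itself comes from this patch.

    # Illustrative sketch; mesh layout, shapes and partition_spec are assumptions.
    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    from torchacc.ops import spmd_flash_attn_varlen_xla

    xr.use_spmd()  # enable SPMD execution mode
    device = xm.xla_device()

    # 1-D mesh over all devices; the axis name 'model' is arbitrary.
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('model',))

    # Varlen layout: q/k/v are (total_tokens, num_heads, head_dim).
    total_tokens, num_heads, head_dim = 1024, 16, 128
    q = torch.randn(total_tokens, num_heads, head_dim,
                    dtype=torch.bfloat16, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Two packed sequences of 512 tokens each.
    cu_seqlens = torch.tensor([0, 512, 1024], dtype=torch.int32, device=device)

    out = spmd_flash_attn_varlen_xla(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=512,
        max_seqlen_k=512,
        causal=True,
        mesh=mesh,
        # Assumed spec: shard the head dimension over the 'model' axis,
        # keep tokens and head_dim replicated.
        partition_spec=(None, 'model', None),
    )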