diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 7a4d953840..1fae9e99f2 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -163,12 +163,10 @@ def run_dpa_with_cp( torch.tensor([q_input_shape[0]], dtype=torch.int32), ] ).cuda() - if kernel_backend == "FlashAttention": - cu_seqlens_q = cu_seqlens_q_padded[:-1] - else: - cu_seqlens_q = torch.cat( - [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)] - ).cuda() + cu_seqlens_q = torch.clone(cu_seqlens_q_padded) + if kernel_backend == "FusedAttention": + cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() + cu_seqlens_q[-1] = cu_seqlens_q[-2] cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: @@ -204,10 +202,8 @@ def run_dpa_with_cp( core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] - ), + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) @@ -276,10 +272,8 @@ def run_dpa_with_cp( core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] - ), + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) @@ -311,7 +305,7 @@ def run_dpa_with_cp( dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) @@ -327,7 +321,7 @@ def run_dpa_with_cp( ).item() == 0 ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True ) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 73994e1873..9866591e8d 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") - if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0): - pytest.skip("THD format is not supported for cuDNN 9.6+!") config = model_configs_fused_attn[model] - if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: - pytest.skip("THD format does not support QGA/MQA yet!") if qkv_format == "thd" and config.attn_bias_type == 
"post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if qkv_format == "thd" and "a2a" in cp_comm_type: pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") - if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": - pytest.skip( - "Sliding window attention only can be supported with the implementation of QKVO A2A!" - ) if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" @@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("FP8 attention cannot work with bias yet!") if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("FP8 attention cannot work with sliding window yet!") + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9f08f67304..3f0267affb 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -125,11 +125,13 @@ def _get_supported_versions(version_min, version_max): _flash_attn_2_5_7_plus = False _flash_attn_2_6_0_plus = False +flash_attn_cuda_bwd = None flash_attn_func = None flash_attn_varlen_func = None -flash_attn_varlen_fwd = None -flash_attn_varlen_bwd = None -flash_attn_cuda_bwd = None +_flash_attn_fwd = None +_flash_attn_bwd = None +_flash_attn_varlen_fwd = None +_flash_attn_varlen_bwd = None try: _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) @@ -141,14 +143,16 @@ def _get_supported_versions(version_min, version_max): ) else: if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd + from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd from flash_attn.flash_attn_interface import ( - _flash_attn_varlen_forward as flash_attn_varlen_fwd, + _flash_attn_varlen_forward as _flash_attn_varlen_fwd, ) from flash_attn.flash_attn_interface import ( - _flash_attn_varlen_backward as flash_attn_varlen_bwd, + _flash_attn_varlen_backward as _flash_attn_varlen_bwd, ) - from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd _flash_attn_is_installed = True _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") @@ -195,11 +199,13 @@ def _get_supported_versions(version_min, version_max): from flashattn_hopper.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 + from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_forward as 
flash_attn_varlen_fwd_v3, + _flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3, ) from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3, + _flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3, ) _flash_attn_3_is_installed = True @@ -602,12 +608,6 @@ def get_attention_backend( "Disabling FusedAttention as it does not support context parallelism with MLA" ) use_fused_attention = False - elif cudnn_version >= (9, 6, 0) and qkv_format == "thd": - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with THD for" - " cuDNN 9.6+" - ) - use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends @@ -1804,12 +1804,20 @@ def forward( else: qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) - pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + pad_between_seqs_q = cu_seqlens_q_padded is not None and not torch.equal( + cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1] + ) + pad_between_seqs_kv = cu_seqlens_kv_padded is not None and not torch.equal( + cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1] + ) max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size - cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size - cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size + cu_seqlens_q_padded = ( + None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // cp_size + ) + cu_seqlens_kv_padded = ( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // cp_size + ) cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] @@ -1882,9 +1890,6 @@ def forward( elif qkv_format == "sbhd": # [s, b, np, hn] -> [2, s//2, b, np, hn] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] - total_tokens_kv = None if qkv_format != "thd" else k.shape[0] - # remove padded tokens at the end - k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]] if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -1907,17 +1912,27 @@ def forward( ) assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" - softmax_lse_in_packed_format = not use_fused_attention and ( - _flash_attn_2_6_0_plus or _use_flash_attn_3 - ) + softmax_lse_in_packed_format = False + if qkv_format == "thd": + if use_fused_attention: + softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) + else: + softmax_lse_in_packed_format = _flash_attn_2_6_0_plus or _use_flash_attn_3 + flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_3_plus: @@ -1943,7 +1958,7 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if use_fused_attention and qkv_format in ["bshd", "sbhd"]: + if 
qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) else: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) @@ -1991,31 +2006,31 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + q_inputs[i % 2] = q if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -2060,18 +2075,27 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv, + ] fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=True, **fa_forward_kwargs, ) @@ -2084,7 +2108,7 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2095,25 +2119,26 @@ def forward( True, False, ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = 
q.view(q.shape[0], -1, *q.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] + elif qkv_format == "thd": + q_inputs[i % 2] = q + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous() - elif qkv_format == "thd": - q_inputs[i % 2] = q - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() @@ -2156,28 +2181,29 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) + fa_forward_args_thd = [] if qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() - # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv // 2, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_forward_kwargs["window_size"] = (-1, -1) fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv // 2, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) @@ -2190,7 +2216,7 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2201,28 +2227,29 @@ def forward( True, True, ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_inputs[i % 2] = q[:, 1, ...] 
+ # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_inputs[i % 2] = q[1] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...].contiguous() - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1].contiguous() - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + q_inputs[i % 2] = q_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -2271,28 +2298,29 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + fa_forward_args_thd = [] if qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) - else: - # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] - q_inputs[i % 2] = ( - q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - ) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q // 2, + max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_forward_kwargs["window_size"] = (-1, -1) fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q // 2, - max_seqlen_kv, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) @@ -2305,7 +2333,7 @@ def forward( cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( @@ -2316,7 +2344,7 @@ def forward( True, True, ) - else: + elif use_fused_attention or qkv_format == "thd": cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size if use_fused_attention: if attn_bias is not None: @@ -2363,18 +2391,27 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, sq, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd 
= [ + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv, + ] fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) @@ -2389,13 +2426,13 @@ def forward( flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] + # [b, np, sq, 1] -> [b, np, sq] or + # [t, np, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) - if qkv_format != "thd" and softmax_lse_in_packed_format: - # [np, t] -> [np, b, sq] - softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view( - q.shape[-2], q.shape[0], -1 - ) + if softmax_lse_in_packed_format: + softmax_lse_per_step[i - 1] = ( + softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() + ) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: @@ -2410,8 +2447,7 @@ def forward( out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) if causal and qkv_format != "thd": - # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format - # [np, b, sq] -> [np, b, 2, sq//2] lse in packed format + # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) @@ -2439,16 +2475,6 @@ def forward( softmax_lse = softmax_lse.to(torch.float) for i in range(cp_size): - out_ = None - if qkv_format == "bshd": - out_per_step[i] = out_per_step[i].view( - out.shape[0], -1, *out.shape[-2:] - ) # pylint: disable=used-before-assignment - out_ = out[:, 1, ...] 
- elif qkv_format == "sbhd": - out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) - out_ = out[1] - if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction( @@ -2471,6 +2497,7 @@ def forward( ) else: if qkv_format in ["bshd", "sbhd"]: + out_ = out.select(seq_dim, 1) flash_attn_fwd_out_correction( out_, out_per_step[i], @@ -2490,9 +2517,6 @@ def forward( softmax_lse_in_packed_format, ) - if qkv_format != "thd" and softmax_lse_in_packed_format: - # [np, b, sq] -> [np, t] - softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1) kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) @@ -2587,7 +2611,6 @@ def forward( ctx.cp_global_ranks = cp_global_ranks ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p - ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale @@ -2597,6 +2620,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 @@ -2646,14 +2670,10 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None - softmax_lse_in_packed_format = not ctx.use_fused_attention and ( - _flash_attn_2_6_0_plus or _use_flash_attn_3 - ) - if causal: - if ctx.qkv_format == "thd" or softmax_lse_in_packed_format: + if ctx.qkv_format == "thd": softmax_lse_ = tex.thd_read_second_half_lse( - softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format + softmax_lse, cu_seqlens_q_padded, ctx.softmax_lse_in_packed_format ) else: # [b, np, sq] -> [b, np, 2, sq//2] @@ -2661,13 +2681,20 @@ def backward(ctx, dout): *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 ) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() - if ctx.use_fused_attention: - # [b, np, sq//2] -> [b, np, sq//2, 1] - softmax_lse_.unsqueeze_(-1) + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() + # [b, np, sq//2] -> [b, np, sq//2, 1] or + # [t//2, np] -> [t//2, np, 1] + softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: - # [b, np, sq] -> [b, np, sq, 1] + if ctx.softmax_lse_in_packed_format: + softmax_lse = softmax_lse.transpose(0, 1).contiguous() + # [b, np, sq] -> [b, np, sq, 1] or + # [t, np] -> [t, np, 1] softmax_lse.unsqueeze_(-1) + dq = None dout_dtype = dout.dtype fused_attn_backend = None fused_attn_qkv_dtype = None @@ -2715,8 +2742,6 @@ def backward(ctx, dout): dout_scale_inv = dout._scale_inv dout = dout._data dq = torch.empty_like(q) - if ctx.qkv_format == "thd" and causal: - dq[cu_seqlens_q_padded[-1] :].fill_(0) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2760,10 +2785,16 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd_v3 + else: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = 
_flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p if _flash_attn_2_4_plus: fa_backward_kwargs["alibi_slopes"] = None @@ -2808,32 +2839,29 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] - dk_, dv_ = None, None + q_, kv_, out_, dout_ = None, None, None, None + dq_, dk_, dv_ = None, None, None if ctx.fp8 and ctx.use_fused_attention: fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] # In reversed order of fwd if causal: if i == (cp_size - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + q_, kv_, out_, dout_ = q, kv, out, dout if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_ = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - out_ = out.view(-1, *out.shape[-3:]) - dout_ = dout.view(-1, *dout.shape[-3:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -2869,15 +2897,16 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, 0) if not _use_flash_attn_3: @@ -2885,42 +2914,36 @@ def backward(ctx, dout): flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=True, **fa_backward_kwargs, ) elif i >= (cp_size - rank - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + # [b, 
2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] + elif ctx.qkv_format == "thd": + q_, out_, dout_ = q, out, dout + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0, ...].contiguous() - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_ = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0].contiguous() - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - out_ = out.view(-1, *out.shape[-3:]) - dout_ = dout.view(-1, *dout.shape[-3:]) - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + kv_ = kv_.contiguous() if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -2958,19 +2981,16 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - if ctx.qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - else: - # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q, + ctx.max_seqlen_kv // 2, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, -1) if not _use_flash_attn_3: @@ -2978,44 +2998,37 @@ def backward(ctx, dout): flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) else: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_, out_, dout_ = q[1], out[1], dout[1] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_, out_, dout_ = [ + 
tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q, out, dout] + ] + kv_ = kv if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, 1, ...].contiguous() - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - out_ = out[:, 1, ...].contiguous() - dout_ = dout[:, 1, ...].contiguous() - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[1].contiguous() - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - out_ = out[1].contiguous() - dout_ = dout[1].contiguous() - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) - kv_ = kv + q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] if ctx.fp8: aux_ctx_tensors = [ softmax_lse_, @@ -3053,23 +3066,16 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - if ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = [] if ctx.qkv_format == "thd": - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) - dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q // 2, + ctx.max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, -1) if not _use_flash_attn_3: @@ -3077,17 +3083,14 @@ def backward(ctx, dout): flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse_, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) @@ -3124,50 +3127,41 @@ def backward(ctx, dout): **fp8_meta_kwargs, ) else: - # [b, sq, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) - dkv_ = torch.empty_like(kv_) - # [b, sq, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) + dq_ = torch.empty_like(q) + dkv_ = torch.empty_like(kv) + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q_per_step[cp_size - i - 1], + 
cu_seqlens_kv_per_step[cp_size - i - 1], + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ] if _use_flash_attn_3 or _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = (-1, -1) if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( - dout_, - q_, - kv_[0], - kv_[1], - out_, + dout, + q, + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + out, softmax_lse, dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) if ctx.fp8: dq = dq_fp8[(rank + i + 1) % cp_size] - if i >= (cp_size - rank - 1) or not causal: - # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal - # [b*sq, np, hn] -> [b, sq, np, hn] if not causal + if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): + # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or + # [sq, b, np, hn] -> [2, sq//2, b, np, hn] dq_ = dq_.view(*dq.shape) - else: - if ctx.qkv_format == "bshd": - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_ = dq_.view(dq.shape[0], *dq.shape[2:]) - elif ctx.qkv_format == "sbhd": - # [b*sq//2, np, hn] -> [sq//2, b, np, hn] - dq_ = dq_.view(-1, *dq.shape[-3:]) if ctx.fp8: if i >= (cp_size - rank - 1) or not causal: @@ -3242,24 +3236,21 @@ def backward(ctx, dout): else: dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment if ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] - dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) - elif ctx.qkv_format == "sbhd": - # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn] - dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:]) - else: - # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal - # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal - dkv_ = dkv_.view(*dkv.shape) + dkv_ = _combine_tensors([dk_, dv_], -2) + elif ctx.qkv_format == "thd": + dkv_ = torch.cat( + (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 + ) # pylint: disable=used-before-assignment + if ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) + dkv_ = dkv_.movedim(-3, 0) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv_ = dkv_.view(*dkv.shape) if ctx.fp8: if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): @@ -3341,13 +3332,9 @@ def backward(ctx, dout): # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) - if ctx.qkv_format == "thd": - dkv_ = torch.empty( - 2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device - ) - dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv) - dkv_[:, 
cu_seqlens_kv_padded[-1] :].fill_(0) - dkv = dkv_ + if ctx.qkv_format == "thd" and not ctx.use_fused_attention: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: dq, dkv = [ @@ -3494,9 +3481,15 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = _flash_attn_fwd_v3 else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_4_plus: @@ -3514,8 +3507,11 @@ def forward( max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) - cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_q = cu_seqlens_q // (2 * cp_size) + cu_seqlens_q_padded = ( + None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // (2 * cp_size) + ) # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) @@ -3570,9 +3566,10 @@ def forward( kv_seq_range_per_step[i][1], ) max_seqlen_kv_ = seq_end_idx - seq_start_idx - cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device - ) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] @@ -3599,15 +3596,19 @@ def forward( window_size=window_size_per_step[i], ) else: - q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + max_seqlen_q, + max_seqlen_kv_, + ] fa_outputs = flash_attn_fwd( q_, k_, v_, - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv_, + *fa_forward_args_thd, causal=causal, window_size=window_size_per_step[i], **fa_forward_kwargs, @@ -3620,9 +3621,9 @@ def forward( if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape)) + out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape)) + out[i - 1].copy_(out_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3711,10 +3712,16 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd_v3 + else: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p if _flash_attn_2_4_plus: 
fa_backward_kwargs["alibi_slopes"] = None @@ -3764,11 +3771,17 @@ def backward(ctx, dout): deterministic=ctx.deterministic, ) else: - batch_size = k_.shape[0] - q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] ] + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + ctx.max_seqlen_q, + max_seqlen_kv, + ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] flash_attn_bwd( @@ -3781,21 +3794,11 @@ def backward(ctx, dout): dq_per_step[i], dk_per_step[i], dv_per_step[i], - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - ctx.max_seqlen_q, - max_seqlen_kv, + *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, window_size=window_size_per_step[i], **fa_backward_kwargs, ) - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) - # [b*s_range, np, hn] -> [b, s_range, np, hn] - dk_per_step[i], dv_per_step[i] = [ - x.view(batch_size, -1, *x.shape[-2:]) - for x in [dk_per_step[i], dv_per_step[i]] - ] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3916,10 +3919,16 @@ def forward( if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd_v3 + else: + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = window_size else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False if _flash_attn_2_3_plus: @@ -4025,24 +4034,25 @@ def forward( **fp8_meta_kwargs, ) else: - # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] - q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] + fa_forward_args_thd = [] + if qkv_format == "thd": + fa_forward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] fa_outputs = flash_attn_fwd( q, k, v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, + *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, ) out, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not _use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] - # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) out = flash_attn_a2a_communicate( @@ -4214,11 +4224,17 @@ def backward(ctx, dout): if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd_v3 + else: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p if _flash_attn_2_3_plus: fa_backward_kwargs["window_size"] = ctx.window_size @@ -4255,8 +4271,15 @@ def backward(ctx, dout): ) else: softmax_lse, rng_state = aux_ctx_tensors - out, dout = [x.view(-1, *x.shape[-2:]) for x 
in [out, dout]] dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + fa_backward_args_thd = [] + if ctx.qkv_format == "thd": + fa_backward_args_thd = [ + cu_seqlens_q, + cu_seqlens_kv, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ] if not _use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( @@ -4269,14 +4292,10 @@ def backward(ctx, dout): dq, dk, dv, - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + *fa_backward_args_thd, causal=causal, **fa_backward_kwargs, ) - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) dq, dk, dv = flash_attn_a2a_communicate( @@ -4400,18 +4419,17 @@ def attn_forward_func_with_cp( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" ) - assert ( + assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" + ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) - assert ( - not sliding_window_attn - or cp_comm_type == "a2a" - or (cp_comm_type == "all_gather" and not use_fused_attention) - ), "The context parallel running configs cannot support sliding window attetnion!" + assert not sliding_window_attn or cp_comm_type in [ + "a2a", + "all_gather", + ], "The context parallel running configs cannot support sliding window attetnion!" args = [ is_training, @@ -5419,8 +5437,8 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, + cu_seqlens_q if qkv_format == "thd" else None, + cu_seqlens_kv if qkv_format == "thd" else None, self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -7215,7 +7233,7 @@ def forward( and cu_seqlens_kv is not None ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" - if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None: + if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None): cu_seqlens_q_padded = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_kv @@ -8151,10 +8169,10 @@ def forward( pad_between_seqs = ( cu_seqlens_q_padded is not None - and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) + and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) ) or ( cu_seqlens_kv_padded is not None - and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) ) attention_params = AttentionParams( diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 50da91a1a1..f947930d23 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1537,10 +1537,10 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ int batch, lse_seqlen; if (lse_packed) { batch = cu_seqlens.size(0) - 1; - lse_seqlen = total_tokens; + lse_seqlen = lse.size(1); NVTE_CHECK(lse.size(0) == num_heads); - NVTE_CHECK(lse.size(1) == lse_seqlen); + NVTE_CHECK(lse_seqlen >= total_tokens); NVTE_CHECK(lse_per_step.size(0) == num_heads); NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1)); } else {
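
Note on the `cu_seqlens` vs. `cu_seqlens_padded` handling above: for THD (packed variable-length) input, `cu_seqlens` accumulates the actual per-sequence token counts while `cu_seqlens_padded` accumulates the padded offsets into the packed buffer, and the revised `pad_between_seqs` check compares all but the last entries because the final offsets may legitimately differ by trailing buffer padding alone. A minimal sketch with hypothetical lengths (not taken from the tests):

```python
import torch

# Hypothetical example: three sequences with 5, 3, and 7 real tokens,
# each padded to a multiple of 4 slots in the packed (THD) buffer.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
padded_lens = torch.tensor([8, 4, 8], dtype=torch.int32)

zero = torch.zeros(1, dtype=torch.int32)
cu_seqlens = torch.cat([zero, seqlens.cumsum(0, dtype=torch.int32)])            # [0, 5, 8, 15]
cu_seqlens_padded = torch.cat([zero, padded_lens.cumsum(0, dtype=torch.int32)])  # [0, 8, 12, 20]

# Only the interior offsets are compared; the last entries are dropped so that
# padding at the very end of the buffer alone does not trigger the padded path.
pad_between_seqs = not torch.equal(cu_seqlens_padded[:-1], cu_seqlens[:-1])
print(pad_between_seqs)  # True here: sequences 0 and 2 carry padding between sequences
```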
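
The repeated halving of `max_seqlen`/`cu_seqlens` and the `tex.thd_read_half_tensor(..., 0 or 1)` calls follow the load-balanced causal split used by the P2P ring path: each sequence is cut into `2 * cp_size` chunks and rank `r` keeps chunks `r` and `2*cp_size - 1 - r`, so every local sequence is made of two halves of `seqlen // (2 * cp_size)` tokens. A rough, pure-Python illustration of that assignment (not the library helpers):

```python
def chunk_ids_for_rank(rank, cp_size):
    """Load-balanced causal split: a sequence is divided into 2*cp_size chunks
    and rank r is assigned the pair (r, 2*cp_size - 1 - r)."""
    return rank, 2 * cp_size - 1 - rank

cp_size = 4
for rank in range(cp_size):
    first_half, second_half = chunk_ids_for_rank(rank, cp_size)
    # Early chunks attend to little context, late chunks to a lot; pairing them
    # keeps the causal-attention work roughly equal across ranks.
    print(f"rank {rank}: chunks ({first_half}, {second_half})")
# rank 0: chunks (0, 7), rank 1: chunks (1, 6), rank 2: chunks (2, 5), rank 3: chunks (3, 4)
```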
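
The forward/backward call sites now choose between the fixed-length and varlen flash-attention entry points and only assemble the `cu_seqlens`/`max_seqlen` positional block for THD. A sketch of that dispatch pattern with stand-in callables (the real `_flash_attn_*` signatures are version-dependent and are not reproduced here):

```python
# Illustrative dispatch only; fixed_len_fwd and varlen_fwd stand in for the
# flash-attn entry points and do not mirror their exact signatures.
def fixed_len_fwd(q, k, v, *, causal, **kwargs):
    ...

def varlen_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
               *, causal, **kwargs):
    ...

def run_fwd(qkv_format, q, k, v, cu_seqlens_q, cu_seqlens_kv,
            max_seqlen_q, max_seqlen_kv, causal, **fa_forward_kwargs):
    # THD needs the varlen entry point plus the extra positional arguments;
    # bshd/sbhd use the fixed-length path and pass no cu_seqlens at all.
    extra_args = []
    if qkv_format == "thd":
        fwd = varlen_fwd
        extra_args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv]
    else:
        fwd = fixed_len_fwd
    return fwd(q, k, v, *extra_args, causal=causal, **fa_forward_kwargs)
```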