From 257345a56d006bb24be890bfd813b1d1299807a8 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 20 Feb 2025 10:13:15 -0800 Subject: [PATCH] [PyTorch] Fix CP implementation with FP8 (#1483) * commit some debug code Signed-off-by: Xiaowei Ren * add more debug info Signed-off-by: Xiaowei Ren * debug code commit and typo fix Signed-off-by: Xiaowei Ren * a typo fix Signed-off-by: Xiaowei Ren * remove debug info Signed-off-by: Xiaowei Ren * do not return lse Signed-off-by: Xiaowei Ren * add amax_per_step for quantizers of CP Signed-off-by: Xiaowei Ren * fix FP8 + CP Signed-off-by: Xiaowei Ren * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bug fix Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * dtype fix Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xiaowei Ren --- transformer_engine/pytorch/attention.py | 262 +++++++++++++++--------- transformer_engine/pytorch/fp8.py | 2 +- 2 files changed, 166 insertions(+), 98 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8584431dc2..d6b9894fc3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1894,11 +1894,12 @@ def forward( fused_attn_backend = None qkv_dtype = q.dtype + amax_per_step = None + S_quantizer_per_step = [None for _ in range(cp_size)] + O_CP_quantizer_per_step = [None for _ in range(cp_size)] # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = False - if fp8: - is_output_fp8 = fp8_meta["recipe"].fp8_mha ( QKV_quantizer, @@ -1919,28 +1920,30 @@ def forward( v, q.__class__ ), "q, k, and v must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) - if not is_input_fp8: + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + if is_input_fp8: + QKV_quantizer = q._quantizer + q, k, v = q._data, k._data, v._data + else: q_f16, k_f16, v_f16 = q, k, v if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16) + q = QKV_quantizer(q_f16)._data if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x) for x in [k_f16, v_f16]] - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer # partial result quantizer + k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # partial result quantizer + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i] + O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() + O_CP_quantizer_per_step[i].amax = amax_per_step[1][i] else: assert False, "FP8 is only supported with Fused Attention!" 
else: q_f16 = q if use_fused_attention: - fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if fp8: - q = q._data - k = k._data - v = v._data - if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) @@ -2067,7 +2070,7 @@ def forward( kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i]) + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data if causal: if i == 0: if pad_between_seqs_q: @@ -2120,6 +2123,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2130,6 +2134,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, @@ -2243,6 +2249,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2253,6 +2260,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -2385,6 +2394,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2395,6 +2405,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q // 2, @@ -2507,6 +2519,7 @@ def forward( if qkv_format in ["bshd", "sbhd"] else kv_inputs[i % 2][1] ) + fp8_meta_kwargs = {} if fp8: q_part = QKV_quantizer.create_tensor_from_data( q_part, fake_dtype=qkv_dtype, internal=True @@ -2517,6 +2530,8 @@ def forward( v_part = QKV_quantizer.create_tensor_from_data( v_part, fake_dtype=qkv_dtype, internal=True ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -2595,7 +2610,7 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: - out_per_step[i - 1] = out_per_step[i - 1].dequantize() + out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) if i == 1: out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) @@ -2697,6 +2712,11 @@ def forward( elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + S_quantizer.amax = amax_cp_fwd[0] + O_CP_quantizer.amax = amax_cp_fwd[1] + out_fp8 = None out_f16 = out.to(qkv_dtype) @@ -2708,7 +2728,7 @@ def forward( if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, kv_save, 
out_save = q, kv, out_fp8._data elif fp8 and is_input_fp8: - q_save, kv_save, out_save = q, k, out_f16 + q_save, kv_save, out_save = q, kv, out_f16 else: q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 @@ -2737,7 +2757,6 @@ def forward( ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer - ctx.qkv_dtype = qkv_dtype ctx.cp_group_a2a = cp_group_a2a ctx.cp_size_a2a = cp_size_a2a @@ -2778,10 +2797,8 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - saved_tensors = ctx.saved_tensors - q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( - restore_from_saved(ctx.tensor_objects, saved_tensors) + restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) ) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] @@ -2843,39 +2860,59 @@ def backward(ctx, dout): dout_dtype = dout.dtype fused_attn_backend = None fused_attn_dqkv_dtype = None + amax_per_step = None + dP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) - dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + dqkv_fp8_torch_dtype = get_fp8_torch_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False + ) + dq_fp8 = torch.empty( + (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device + ) + dkv_fp8 = torch.empty( + (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device + ) dkv_fp8_ = torch.empty_like(dkv_fp8) if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - fused_attn_dqkv_dtype = dout._fp8_dtype - dout = dout._data + ctx.dO_quantizer = dout._quantizer else: dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = dout._fp8_dtype - dout = dout._data + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer - fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_CP_quantizer + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i] + dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() + dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i] else: assert False, "FP8 is only supported with Fused Attention!" 
else: - if ctx.fp8_meta is not None and ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - kv = ctx.QKV_quantizer.create_tensor_from_data( - kv, fake_dtype=ctx.qkv_dtype, internal=True - ) - q, kv = q.dequantize(), kv.dequantize() - if cp_size_a2a == 1: - dout = dout.dequantize() + if ctx.fp8_meta is not None: + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q = q.dequantize(dtype=ctx.qkv_dtype) + kv = kv.dequantize(dtype=ctx.qkv_dtype) + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + if cp_size_a2a == 1: + dout = dout.dequantize(dtype=dout_dtype) + else: + ctx.dO_quantizer = dout._quantizer + dout = dout._data dq = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2902,9 +2939,10 @@ def backward(ctx, dout): True, ) if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = ctx.dO_quantizer.create_tensor_from_data(data=dout, internal=True) - dout = dout.dequantize() - dout = dout._data + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + dout = dout.dequantize(dtype=dout_dtype) out = out.view(*q.shape) dout = dout.view(*q.shape) @@ -3020,8 +3058,10 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3133,8 +3173,10 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, @@ -3250,8 +3292,10 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, @@ -3282,7 +3326,6 @@ def backward(ctx, dout): dq_ = dq_._data dk_ = dk_._data dv_ = dv_._data - else: dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) @@ -3333,20 +3376,22 @@ def backward(ctx, dout): if ctx.fp8: q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype + q_part, fake_dtype=ctx.qkv_dtype, internal=True ) k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype + k_part, fake_dtype=ctx.qkv_dtype, internal=True ) v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype + v_part, fake_dtype=ctx.qkv_dtype, internal=True ) out_part = 
ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype + out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype + dout_part, fake_dtype=dout_dtype, internal=True ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3555,13 +3600,20 @@ def backward(ctx, dout): dkv.add_(dkv_) if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax = amax_cp_bwd[0] + ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1] if ctx.qkv_format in ["bshd", "sbhd"]: # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq_fp8) - dkv = ctx.dQKV_quantizer.create_tensor_from_data(dkv_fp8) - dq, dkv = [x.dequantize() for x in [dq, dkv]] + dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dq_fp8, fake_dtype=torch.float32, internal=True + ) + dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: @@ -3606,9 +3658,9 @@ def backward(ctx, dout): attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) # converting torch.uint8 to float8tensor if ctx.fp8 and ctx.is_input_fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") return ( @@ -4227,21 +4279,20 @@ def forward( # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = False - if fp8: - is_output_fp8 = fp8_meta["recipe"].fp8_mha QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) ) if fp8: if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." 
is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: + QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): @@ -4350,31 +4401,24 @@ def forward( out = out_fp8._data else: out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=False + out, fake_dtype=qkv_dtype, internal=True ) - out_f16 = out_fp8.dequantize() + out_f16 = out_fp8.dequantize(dtype=qkv_dtype) out_ret = out_f16 else: out_ret = out - if fp8: - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - elif is_input_fp8: - q_fp8 = QKV_quantizer.create_tensor_from_data( - q, fake_dtype=qkv_dtype, internal=False - ) - k_fp8 = QKV_quantizer.create_tensor_from_data( - k, fake_dtype=qkv_dtype, internal=False - ) - v_fp8 = QKV_quantizer.create_tensor_from_data( - v, fake_dtype=qkv_dtype, internal=False - ) - q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out - else: - q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 - else: + if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out + else: + if is_input_fp8: + q_save, k_save, v_save = q, k, v + else: + q_save, k_save, v_save = q_f16, k_f16, v_f16 + if is_output_fp8: + out_save = out + else: + out_save = out_f16 tensors_to_save, tensor_objects = prepare_for_saving( q_save, @@ -4397,7 +4441,6 @@ def forward( ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer - ctx.qkv_dtype = qkv_dtype ctx.batch_size = batch_size ctx.cp_group = cp_group @@ -4436,27 +4479,24 @@ def backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - dout_dtype = dout.dtype qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") + dout_dtype = dout.dtype fused_attn_backend = None fused_attn_dqkv_dtype = None if ctx.fp8: - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_dqkv_dtype = fp8_dtype_backward - if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - dout_fp8 = dout - dout = dout_fp8._data + ctx.dO_quantizer = dout._quantizer else: - dout_f16 = dout - dout = ctx.dO_quantizer(dout_f16)._data + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer @@ -4465,12 +4505,25 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - q, k, v, out, dout = [x.dequantize() for x in [q, k, v, out, dout]] + if ctx.fp8_meta is not None: + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" 
+ ctx.dO_quantizer = dout._quantizer + dout = dout._data + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + k = ctx.QKV_quantizer.create_tensor_from_data( + k, fake_dtype=ctx.qkv_dtype, internal=True + ) + v = ctx.QKV_quantizer.create_tensor_from_data( + v, fake_dtype=ctx.qkv_dtype, internal=True + ) + q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: @@ -4481,6 +4534,15 @@ def backward(ctx, dout): out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: + out = ctx.O_quantizer.create_tensor_from_data( + out, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + out = out.dequantize(dtype=ctx.qkv_dtype) + dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -4531,7 +4593,7 @@ def backward(ctx, dout): out_part, fake_dtype=ctx.qkv_dtype, internal=True ) dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=ctx.qkv_dtype, internal=True + dout_part, fake_dtype=dout_dtype, internal=True ) dq, dk, dv, _ = fused_attn_bwd( @@ -4602,11 +4664,17 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + dq = ctx.dQKV_quantizer.create_tensor_from_data( + dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dk = ctx.dQKV_quantizer.create_tensor_from_data( + dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dv = ctx.dQKV_quantizer.create_tensor_from_data( + dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) if not ctx.is_input_fp8: - dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] + dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 254bcf12e1..f788368112 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -56,7 +56,7 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch. fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): return torch.float8_e4m3fn - return torch.float8_e5m2fn + return torch.float8_e5m2 def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
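
The core of this fix is the per-step amax bookkeeping for context parallelism: every ring step gets its own copy of the S/O (and, in backward, dP/dQKV) quantizer whose amax aliases one slot of a shared (2, cp_size) buffer, and the true amax is reduced over steps once after the loop. A minimal sketch of that pattern, using a hypothetical ToyQuantizer stand-in rather than the Transformer Engine quantizer classes:

    import torch

    class ToyQuantizer:
        # Hypothetical stand-in: only the copy()/amax surface used by the patch.
        def __init__(self, amax: torch.Tensor):
            self.amax = amax

        def copy(self):
            return ToyQuantizer(self.amax)

        def observe(self, x: torch.Tensor):
            # a real quantizer would quantize x; here we only record its amax
            self.amax.copy_(x.abs().amax())

    cp_size = 4
    amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32)

    S_quantizer = ToyQuantizer(torch.zeros(()))
    O_CP_quantizer = ToyQuantizer(torch.zeros(()))
    S_per_step = [None] * cp_size
    O_per_step = [None] * cp_size
    for i in range(cp_size):
        S_per_step[i] = S_quantizer.copy()
        S_per_step[i].amax = amax_per_step[0][i]   # alias into the shared buffer
        O_per_step[i] = O_CP_quantizer.copy()
        O_per_step[i].amax = amax_per_step[1][i]

    for i in range(cp_size):                       # one attention call per CP step
        S_per_step[i].observe(torch.randn(8, 8) * (i + 1))
        O_per_step[i].observe(torch.randn(8, 8))

    amax_cp_fwd = amax_per_step.amax(dim=1)        # reduce over steps, as in the patch
    S_quantizer.amax = amax_cp_fwd[0]
    O_CP_quantizer.amax = amax_cp_fwd[1]

The same two-row buffer is reused in backward for dP_quantizer_per_step and dQKV_CP_quantizer_per_step, with amax_per_step.amax(dim=1) folded back into ctx.dP_quantizer and ctx.dQKV_CP_quantizer.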
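
A second ingredient of the fix is accumulating FP8 partial results in float32: per-step outputs are dequantized with dequantize(dtype=torch.float32) before being merged, and the stacked per-step dq_fp8/dkv_fp8 gradients are dequantized to float32, summed over the cp dimension, and only then cast to the gradient dtype. A toy illustration of that reduction order (assumes a PyTorch build with the float8 dtypes; the scale below is an arbitrary stand-in for a real quantization scale):

    import torch

    cp_size, shape = 4, (2, 16, 8, 64)
    scale = 8.0  # toy per-tensor quantization scale

    # per-step partial gradients stored as raw float8 data, like dq_fp8 in the patch
    dq_fp8 = (torch.randn(cp_size, *shape) * scale).to(torch.float8_e4m3fn)

    # dequantize to float32, reduce over the cp_size dimension, cast once at the end
    dq = (dq_fp8.to(torch.float32) / scale).sum(dim=0).to(torch.bfloat16)
    print(dq.shape, dq.dtype)  # torch.Size([2, 16, 8, 64]) torch.bfloat16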
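
Finally, the one-line change in fp8.py matters because PyTorch only defines torch.float8_e4m3fn and torch.float8_e5m2; torch.float8_e5m2fn does not exist, so the old return failed whenever a gradient tensor dtype was requested. A hypothetical stand-alone version of the mapping the patched get_fp8_torch_dtype implements (using a plain string instead of the Recipe/Format objects):

    import torch

    def fp8_torch_dtype(fp8_format: str, fprop_tensor: bool = True) -> torch.dtype:
        # HYBRID uses E4M3 for forward (fprop) tensors and E5M2 for gradients
        if fp8_format == "E4M3" or (fp8_format == "HYBRID" and fprop_tensor):
            return torch.float8_e4m3fn
        return torch.float8_e5m2

    assert fp8_torch_dtype("HYBRID", fprop_tensor=False) is torch.float8_e5m2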