Commit b459ccc

Oleg-Goncharov, ptrendx, timmoon10, and ksivaman authored
[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding (NVIDIA#668)
* Modified MHA and DPA logic to use causal softmax and FA for inference
  Signed-off-by: Oleg Goncharov <[email protected]>

* Adjusted unfused attention and softmax logic for inference
  Signed-off-by: Oleg Goncharov <[email protected]>

* Cleaned up the code per pylint
  Signed-off-by: Oleg Goncharov <[email protected]>

* Added test cases to evaluate numerics of incremental decoding
  Signed-off-by: Oleg Goncharov <[email protected]>

* Apply suggestions from code review
  Co-authored-by: Tim Moon <[email protected]>
  Signed-off-by: Oleg Goncharov <[email protected]>

* Apply suggestions from code review [sequence start-end]
  Co-authored-by: Tim Moon <[email protected]>
  Signed-off-by: Oleg Goncharov <[email protected]>

* Apply suggestions from code review [inference_params offset update]
  Co-authored-by: Tim Moon <[email protected]>
  Signed-off-by: Oleg Goncharov <[email protected]>

* Fixed bug in KV-cache indices and updated test suite
  Signed-off-by: Oleg Goncharov <[email protected]>

* Added inference_params description and applied suggestions from the code review
  Signed-off-by: Oleg Goncharov <[email protected]>

* Adjusted absolute tolerances in numerics tests
  Signed-off-by: Oleg Goncharov <[email protected]>

* Cleaned up the files per pylint
  Signed-off-by: Oleg Goncharov <[email protected]>

---------

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 728e335 commit b459ccc
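
As context for the diffs below, here is a minimal sketch of the single-token decoding loop this change enables. It is modeled on the `test_kv_cache_accuracy` test added in `tests/pytorch/test_numerics.py`; the sizes, dtype, and the choice of `MultiheadAttention` are illustrative, everything else follows the test.

# Minimal sketch of incremental decoding with a KV-cache (illustrative sizes/dtype).
import torch
from transformer_engine.pytorch import MultiheadAttention, InferenceParams

S, B, H, D = 16, 2, 12, 768   # sequence length, batch size, attention heads, hidden size

model = (
    MultiheadAttention(
        hidden_size=D,
        num_attention_heads=H,
        qkv_format="sbhd",
        layer_number=1,
        attention_dropout=0.0,
    )
    .to(dtype=torch.bfloat16)
    .cuda()
    .eval()
)

# Pre-allocates the KV-cache for at most B sequences of at most S tokens.
inference_params = InferenceParams(max_batch_size=B, max_sequence_length=S)

tokens = torch.randn((S, B, D), dtype=torch.bfloat16, device="cuda")
outputs = []
with torch.no_grad():
    for i in range(S):
        # Feed one new token per step; keys/values of earlier steps come from the cache.
        step_output = model(
            hidden_states=tokens[i].view(1, B, D),
            inference_params=inference_params,
        )
        outputs.append(step_output)
        # Advance the cache offset only after the complete forward pass of this step.
        inference_params.sequence_len_offset += 1

The new test below runs exactly this kind of loop and checks that its outputs match a single full-sequence forward pass within dtype-dependent tolerances.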

File tree: 3 files changed, +243 -85 lines changed


tests/pytorch/test_numerics.py (+116, -1)
@@ -22,7 +22,7 @@
 )
 from transformer_engine.pytorch import (
     DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
-    MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
+    MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
 )
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
@@ -1397,3 +1397,118 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     y_bshd = block_bshd(x_bshd)
 
     assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
+
+
+model_configs_inference = {
+    #   hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
+    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
+}
+backends_inference = ["FlashAttention", "UnfusedAttention"]
+module_inference = ["TransformerLayer", "MultiheadAttention"]
+input_formats_inference = ["sbhd", "bshd"]
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model_key", model_configs_inference.keys())
+@pytest.mark.parametrize("use_RoPE", all_boolean)
+@pytest.mark.parametrize("input_format", input_formats_inference)
+@pytest.mark.parametrize("module", module_inference)
+@pytest.mark.parametrize("backend", backends_inference)
+def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "0"
+
+    if backend == "FlashAttention":
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+    elif backend == "FusedAttention":
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+
+    config = model_configs_inference[model_key]
+
+    S = config.seq_len
+    B = bs
+    H = config.num_attention_heads
+    D = config.hidden_size
+    head_size = config.embed
+    layer_number = 1
+
+    # Limits the max size of KV-cache
+    B_max = B
+    S_max = S + 2
+
+    if module == "TransformerLayer":
+        model = (
+            TransformerLayer(
+                hidden_size=D,
+                ffn_hidden_size=4 * D,
+                num_attention_heads=H,
+                attn_input_format=input_format,
+                layer_number=layer_number,
+                attention_dropout=0.0,
+            )
+            .to(dtype=dtype)
+            .cuda()
+            .eval()
+        )
+    else:
+        model = (
+            MultiheadAttention(
+                hidden_size=D,
+                num_attention_heads=H,
+                qkv_format=input_format,
+                layer_number=layer_number,
+                attention_dropout=0.0,
+            )
+            .to(dtype=dtype)
+            .cuda()
+            .eval()
+        )
+
+    inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
+    rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
+
+    input = torch.randn((S, B, D), dtype=dtype, device="cuda")
+    if input_format == "bshd":
+        input = input.transpose(0, 1).contiguous()
+
+    incremental_output = torch.zeros_like(input)
+
+    # Generate output for the entire sequence
+    full_output = model(
+        hidden_states=input,
+        rotary_pos_emb=rotary_freqs if use_RoPE else None)
+
+    # Incrementally generate outputs using KV-cache
+    for i in range(S):
+        if input_format == "sbhd":
+            incremental_input = input[i].view(1, B, D)
+        else:
+            incremental_input = input[:, i, :].view(B, 1, D)
+
+        line_output = model(
+            hidden_states=incremental_input,
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_freqs if use_RoPE else None)
+
+        inference_params.sequence_len_offset += 1
+
+        if input_format == "sbhd":
+            incremental_output[i] = line_output.view(B, D)
+        else:
+            incremental_output[:, i, :] = line_output.view(B, D)
+
+    if module == "TransformerLayer":
+        atol = {
+            torch.float32: 5e-3,
+            torch.half: 5e-3,
+            torch.bfloat16: 5e-2,
+        }
+    else:
+        atol = {
+            torch.float32: 1e-3,
+            torch.half: 1e-3,
+            torch.bfloat16: 1e-2,
+        }
+
+    # Check if the fully generated output matches the one generated incrementally
+    assert_allclose(full_output, incremental_output, atol[dtype])
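
The test feeds one token per step, but nothing in the adjusted attention logic limits a step to a single token: in `DotProductAttention.forward` the cache window is sized from the incoming keys (`sequence_end = sequence_start + key_layer.size(0)`), which is what a speculative-decoding verification step needs. Below is a hedged sketch of such a multi-token step; the function name `decode_chunk`, the `num_accepted` parameter, and the rollback convention are illustrative assumptions, not part of this commit.

# Hypothetical multi-token step: push a chunk of draft tokens through the cached-KV
# attention path in one forward pass. `model` and `inference_params` are assumed to be
# set up as in the sketch near the top of this page ("sbhd" layout).
def decode_chunk(model, inference_params, draft_tokens, num_accepted):
    """draft_tokens: [s_draft, b, d]; num_accepted: how many of those tokens to keep cached."""
    # The chunk is written into the KV-cache starting at sequence_len_offset and attends
    # to all previously cached positions plus itself under the causal mask.
    chunk_output = model(
        hidden_states=draft_tokens,
        inference_params=inference_params,
    )
    # Advance the offset by the accepted prefix only; cache slots of rejected tokens are
    # simply overwritten by the next call (an assumed usage pattern, not enforced here).
    inference_params.sequence_len_offset += num_accepted
    return chunk_output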

transformer_engine/pytorch/attention.py (+86, -49)
@@ -84,7 +84,6 @@
 
 __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
 
-
 class InferenceParams: # pylint: disable=too-few-public-methods
     """
     Inference parameters that are passed to the main model in order
@@ -1180,7 +1179,7 @@ def apply_rotary_pos_emb(
     Parameters
     ----------
     t: torch.Tensor
-        Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
+        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
         rotary positional embedding will be applied.
     freqs: torch.Tensor
         Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
@@ -2523,6 +2522,7 @@ def forward(
         core_attention_bias: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
+        inference_params: Optional[InferenceParams] = None,
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -2616,6 +2616,16 @@
                     to the attention score of query i and key j.
        fast_zero_fill: bool, default = `True`
                     Whether to use the fast path to set output tensors to 0 or not.
+       inference_params: Optional[InferenceParams], default = `None`
+                    Optimizes execution performance during inference by caching Keys and Values of the
+                    current decoding iteration. These cached values are appended to the K and V values
+                    computed in previous iterations, eliminating the need to recalculate them for the
+                    entire sequence.
+                    Initialization of `inference_params` is required prior to use to ensure sufficient
+                    memory allocation.
+                    Adjustments of the sequence_len_offset should be done after a complete forward pass.
+                    If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
+                    Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
        """
 
        assert (
@@ -2643,6 +2653,39 @@ def forward(
         if qkv_format is None:
             qkv_format = self.qkv_format
 
+        if inference_params is not None:
+            assert self.layer_number is not None, "Layer number must be set!"
+
+            if qkv_format == "bshd":
+                key_layer = key_layer.transpose(0, 1)
+                value_layer = value_layer.transpose(0, 1)
+
+            (inference_key_memory, inference_value_memory,
+            ) = inference_params.key_value_memory_dict[self.layer_number]
+
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= inference_key_memory.size(1)
+
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
+
+            # Copy keys and values into KV-cache
+            inference_key_memory[
+                sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
+            inference_value_memory[
+                sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
+            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
+
+            if qkv_format == "bshd":
+                key_layer = key_layer.transpose(0, 1)
+                value_layer = value_layer.transpose(0, 1)
+
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
         assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
                 and value_layer.shape[-2] == self.num_gqa_groups_per_partition
                 ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
@@ -2721,12 +2764,15 @@ def forward(
             use_flash_attention = False
 
         # Filter: cross attention + causal mask.
-        if (_flash_attn_2_1_plus
+        # (in training mode)
+        if (inference_params is None
+            and _flash_attn_2_1_plus
             and "causal" in attn_mask_type
-            and max_seqlen_q != max_seqlen_kv):
+            and max_seqlen_q != max_seqlen_kv
+        ):
             warnings.warn(
-                "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
-                "for causal mask in cross attention. See "
+                "In training mode, disable the use of FlashAttention since version 2.1+ has "
+                "changed its behavior for causal mask in cross attention. See "
                 "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
             )
             use_flash_attention = False
@@ -2753,7 +2799,11 @@ def forward(
         if attn_mask_type == "arbitrary":
             use_flash_attention = False
             use_fused_attention = False
-        if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
+
+        if (inference_params is None
+            and "causal" in attn_mask_type
+            and max_seqlen_q != max_seqlen_kv
+        ):
             use_unfused_attention = False
 
         # Filter: bias.
@@ -3446,12 +3496,12 @@ def forward(
         ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
 
         # =================================================
-        # Pre-allocate memory for key-values for inference.
+        # Pre-allocate memory for key-values for inference
         # =================================================
 
         if inference_params and self.layer_number is not None:
             if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_seq_len = inference_params.max_sequence_length
                 inf_max_batch_size = inference_params.max_batch_size
                 inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
@@ -3469,9 +3519,9 @@ def forward(
                 inference_value_memory,
             ) = inference_params.key_value_memory_dict[self.layer_number]
 
-        # =====================
+        # ======================
         # Query, Key, and Value
-        # =====================
+        # ======================
 
         if self.attention_type == "self":
             # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
@@ -3593,51 +3643,37 @@ def forward(
         )
         query_layer = query_layer.view(*new_tensor_shape)
 
-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
+        # ======================================================
+        # Apply relative positional encoding (rotary embedding)
+        # ======================================================
 
-        # duplicate the pos_emb for self attention
         if rotary_pos_emb is not None:
+            # duplicate the pos_emb for self attention
             if not isinstance(rotary_pos_emb, tuple):
                 rotary_pos_emb = ((rotary_pos_emb,) * 2)
 
-        if inference_params and self.layer_number is not None:
-            batch_start = inference_params.batch_size_offset
-            batch_end = batch_start + key_layer.size(1)
-            assert batch_end <= inference_key_memory.size(1)
-            sequence_start = inference_params.sequence_len_offset
-            sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
-            # Copy key and values.
-            inference_key_memory[
-                sequence_start:sequence_end, batch_start:batch_end, ...
-            ] = key_layer
-            inference_value_memory[
-                sequence_start:sequence_end, batch_start:batch_end, ...
-            ] = value_layer
-            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
-            value_layer = inference_value_memory[
-                :sequence_end, batch_start:batch_end, ...
-            ]
-
-            # adjust the key rotary positional embedding
-            if rotary_pos_emb is not None:
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
-                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
-                rotary_pos_emb = (q_pos_emb, k_pos_emb)
-
-        # ==================================
-        # core attention computation
-        # ==================================
-
-        # apply relative positional encoding (rotary embedding)
-        if rotary_pos_emb is not None:
             q_pos_emb, k_pos_emb = rotary_pos_emb
+
+            # adjust key and value for inference
+            if inference_params is not None:
+                if self.qkv_format == "sbhd":
+                    sequence_length = key_layer.size(0)
+                elif self.qkv_format == "bshd":
+                    sequence_length = key_layer.size(1)
+
+                sequence_start = inference_params.sequence_len_offset
+                sequence_end = sequence_start + sequence_length
+
+                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
+                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
+
             query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
             key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
 
+        # ===========================
+        # Core attention computation
+        # ===========================
+
         context_layer = self.core_attention(
             query_layer,
             key_layer,
@@ -3653,11 +3689,12 @@ def forward(
             core_attention_bias=core_attention_bias,
             alibi_slopes=alibi_slopes,
             fast_zero_fill=fast_zero_fill,
+            inference_params=inference_params,
         )
 
-        # =================
+        # ===================
        # Output. [sq, b, h]
-        # =================
+        # ===================
 
         projection_output = self.proj(
             context_layer, is_first_microbatch=is_first_microbatch