__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]

-
class InferenceParams: # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
@@ -1180,7 +1179,7 @@ def apply_rotary_pos_emb(
    Parameters
    ----------
    t: torch.Tensor
-        Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
+        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
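For orientation only, a minimal shape sketch of the call form used later in this diff, apply_rotary_pos_emb(t, freqs, tensor_format, fused=True); the import path and the random freqs values are illustrative assumptions rather than part of this change:

import torch
from transformer_engine.pytorch.attention import apply_rotary_pos_emb  # assumed module path

s, b, h, d = 128, 2, 16, 64
t = torch.randn(s, b, h, d, device="cuda")                           # "sbhd" input, [s, b, h, d]
freqs = torch.randn(s, 1, 1, d, dtype=torch.float32, device="cuda")  # [s2, 1, 1, d2]; here s2 == s, d2 == d
out = apply_rotary_pos_emb(t, freqs, "sbhd", fused=True)             # output keeps the shape of t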
@@ -2523,6 +2522,7 @@ def forward(
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        fast_zero_fill: bool = True,
+        inference_params: Optional[InferenceParams] = None,
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.
@@ -2616,6 +2616,16 @@ def forward(
                    to the attention score of query i and key j.
        fast_zero_fill: bool, default = `True`
                    Whether to use the fast path to set output tensors to 0 or not.
+        inference_params: Optional[InferenceParams], default = `None`
+                    Optimizes inference performance by caching the keys and values of the
+                    current decoding iteration and appending them to the keys and values cached
+                    in previous iterations, so that they do not need to be recomputed for the
+                    entire sequence.
+                    `inference_params` must be initialized before use so that sufficient memory
+                    is allocated for the cache.
+                    `sequence_len_offset` should be advanced only after a complete forward pass.
+                    If rotary positional embeddings (RoPE) are used, they must be prepared beforehand.
+                    Supports the "sbhd" and "bshd" layouts; "sbhd" is the more efficient of the two.
        """

        assert (
@@ -2643,6 +2653,39 @@ def forward(
        if qkv_format is None:
            qkv_format = self.qkv_format

+        if inference_params is not None:
+            assert self.layer_number is not None, "Layer number must be set!"
+
+            # The KV cache is stored in "sbhd" layout; convert "bshd" inputs before copying.
+            if qkv_format == "bshd":
+                key_layer = key_layer.transpose(0, 1)
+                value_layer = value_layer.transpose(0, 1)
+
+            (inference_key_memory, inference_value_memory,
+            ) = inference_params.key_value_memory_dict[self.layer_number]
+
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= inference_key_memory.size(1)
+
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
+
+            # Copy keys and values into the KV cache.
+            inference_key_memory[
+                sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
+            inference_value_memory[
+                sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
+            value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
+
+            # Restore the caller's "bshd" layout.
+            if qkv_format == "bshd":
+                key_layer = key_layer.transpose(0, 1)
+                value_layer = value_layer.transpose(0, 1)
+
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
                and value_layer.shape[-2] == self.num_gqa_groups_per_partition
                ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
@@ -2721,12 +2764,15 @@ def forward(
            use_flash_attention = False

        # Filter: cross attention + causal mask.
-        if (_flash_attn_2_1_plus
+        # (in training mode only)
+        if (inference_params is None
+            and _flash_attn_2_1_plus
            and "causal" in attn_mask_type
-            and max_seqlen_q != max_seqlen_kv):
+            and max_seqlen_q != max_seqlen_kv
+        ):
            warnings.warn(
-                "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
-                "for causal mask in cross attention. See "
+                "Disabling the use of FlashAttention in training mode since version 2.1+ has "
+                "changed its behavior for causal mask in cross attention. See "
                "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
            )
            use_flash_attention = False
@@ -2753,7 +2799,11 @@ def forward(
        if attn_mask_type == "arbitrary":
            use_flash_attention = False
            use_fused_attention = False
-        if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
+
+        if (inference_params is None
+            and "causal" in attn_mask_type
+            and max_seqlen_q != max_seqlen_kv
+        ):
            use_unfused_attention = False

        # Filter: bias.
@@ -3446,12 +3496,12 @@ def forward(
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
-        # Pre-allocate memory for key-values for inference.
+        # Pre-allocate memory for key-values for inference
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
@@ -3469,9 +3519,9 @@ def forward(
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

-        # =====================
+        # ======================
        # Query, Key, and Value
-        # =====================
+        # ======================

        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
@@ -3593,51 +3643,37 @@ def forward(
            )
            query_layer = query_layer.view(*new_tensor_shape)

-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
+        # ======================================================
+        # Apply relative positional encoding (rotary embedding)
+        # ======================================================

-        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
+            # duplicate the pos_emb for self attention
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

-        if inference_params and self.layer_number is not None:
-            batch_start = inference_params.batch_size_offset
-            batch_end = batch_start + key_layer.size(1)
-            assert batch_end <= inference_key_memory.size(1)
-            sequence_start = inference_params.sequence_len_offset
-            sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
-            # Copy key and values.
-            inference_key_memory[
-                sequence_start:sequence_end, batch_start:batch_end, ...
-            ] = key_layer
-            inference_value_memory[
-                sequence_start:sequence_end, batch_start:batch_end, ...
-            ] = value_layer
-            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
-            value_layer = inference_value_memory[
-                :sequence_end, batch_start:batch_end, ...
-            ]
-
-            # adjust the key rotary positional embedding
-            if rotary_pos_emb is not None:
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
-                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
-                rotary_pos_emb = (q_pos_emb, k_pos_emb)
-
-        # ==================================
-        # core attention computation
-        # ==================================
-
-        # apply relative positional encoding (rotary embedding)
-        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
+
+            # adjust the rotary positional embeddings for inference
+            if inference_params is not None:
+                if self.qkv_format == "sbhd":
+                    sequence_length = key_layer.size(0)
+                elif self.qkv_format == "bshd":
+                    sequence_length = key_layer.size(1)
+
+                sequence_start = inference_params.sequence_len_offset
+                sequence_end = sequence_start + sequence_length
+
+                q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
+                k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
+
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

+        # ===========================
+        # Core attention computation
+        # ===========================
+
        context_layer = self.core_attention(
            query_layer,
            key_layer,
@@ -3653,11 +3689,12 @@ def forward(
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            fast_zero_fill=fast_zero_fill,
+            inference_params=inference_params,
        )

-        # =================
+        # ===================
        # Output. [sq, b, h]
-        # =================
+        # ===================

        projection_output = self.proj(
            context_layer, is_first_microbatch=is_first_microbatch
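For end-to-end context, a minimal, hypothetical decoding-loop sketch based on the docstring and hunks above; the constructor signatures InferenceParams(max_batch_size, max_sequence_length) and MultiheadAttention(hidden_size, num_attention_heads, layer_number=...), and the forward call shown, are assumptions rather than text taken from this change:

import torch
from transformer_engine.pytorch import InferenceParams, MultiheadAttention  # both exported in __all__ above

hidden_size, heads, batch, max_len = 1024, 16, 2, 512

# layer_number keys the per-layer entry in inference_params.key_value_memory_dict.
mha = MultiheadAttention(hidden_size, heads, layer_number=1).cuda()

# Initialize once so the KV cache can be pre-allocated for the whole generation.
inference_params = InferenceParams(max_batch_size=batch, max_sequence_length=max_len)

hidden = torch.randn(1, batch, hidden_size, device="cuda")      # one new position per step, [s, b, h]
for _ in range(max_len):
    # If RoPE is used, build the full-length embeddings once before the loop
    # and pass them in here, as the docstring requires.
    out = mha(hidden, inference_params=inference_params)
    # Advance the cache offset only after the complete forward pass.
    inference_params.sequence_len_offset += 1
    hidden = torch.randn(1, batch, hidden_size, device="cuda")  # next decoded position (placeholder)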