1
1
from abc import ABC , abstractmethod
2
- from typing import Any , Dict , Optional , Tuple , Type
2
+ from typing import Any , Dict , Optional , Tuple , Type , TypedDict
3
3
4
4
import torch
5
5
import torch .nn as nn
8
8
from executorch .examples .models .llama .rope import Rope
9
9
10
10
11
+ class ForwardOptions (TypedDict , total = False ):
12
+ """Optional parameters for `Attention.forward` (compatible with Python 3.10+)."""
13
+
14
+ mask : Optional [torch .Tensor ]
15
+ input_pos : Optional [torch .Tensor ]
16
+ in_cache_state : Optional [Any ]
17
+ out_cache_state : Optional [Any ]
18
+
19
+
11
20
class Attention (nn .Module , ABC ):
12
21
"""Abstract base class for attention mechanisms with unified interface."""
13
22
@@ -17,19 +26,14 @@ def forward(
17
26
x : torch .Tensor ,
18
27
freqs_cos : torch .Tensor ,
19
28
freqs_sin : torch .Tensor ,
20
- mask : Optional [torch .Tensor ] = None ,
21
- input_pos : Optional [torch .Tensor ] = None ,
22
- in_cache_state : Optional [Any ] = None ,
23
- out_cache_state : Optional [Any ] = None ,
29
+ ** kwargs : ForwardOptions ,
24
30
) -> Tuple [torch .Tensor , Optional [Any ]]:
25
31
"""Forward pass for attention mechanism.
26
32
27
33
Args:
28
34
x: Input tensor of shape (batch_size, seq_len, dim)
29
35
freqs_cos, freqs_sin: Rotary position embedding frequencies
30
- mask: Optional attention mask
31
- input_pos: Positions for KV cache updates
32
- in_cache_state/out_cache_state: Cache states
36
+ **kwargs: optional arguments grouped as a ForwardOptions TypedDict (mask, input_pos, cache states)
33
37
34
38
Returns:
35
39
Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
209
213
x : torch .Tensor ,
210
214
freqs_cos : torch .Tensor ,
211
215
freqs_sin : torch .Tensor ,
212
- mask : Optional [torch .Tensor ] = None ,
213
- input_pos : Optional [torch .Tensor ] = None ,
214
- in_cache_state : Optional [Any ] = None ,
215
- out_cache_state : Optional [Any ] = None ,
216
+ ** kwargs : ForwardOptions ,
216
217
) -> Tuple [torch .Tensor , Optional [Any ]]:
218
+ input_pos = kwargs .get ("input_pos" )
217
219
bsz , seqlen , _ = x .shape
218
220
219
221
# QKV
0 commit comments