
Commit a5c7609

Single location to update optional args for all attentions
Differential Revision: D68988021 Pull Request resolved: #8128
1 parent e92bb7a commit a5c7609


1 file changed (+14, -12 lines)


examples/models/llama/attention.py

+14-12
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type, TypedDict
 
 import torch
 import torch.nn as nn
@@ -8,6 +8,15 @@
 from executorch.examples.models.llama.rope import Rope
 
 
+class ForwardOptions(TypedDict, total=False):
+    """Optional parameters for `Attention.forward` (compatible with Python 3.10 and above)."""
+
+    mask: Optional[torch.Tensor]
+    input_pos: Optional[torch.Tensor]
+    in_cache_state: Optional[Any]
+    out_cache_state: Optional[Any]
+
+
 class Attention(nn.Module, ABC):
     """Abstract base class for attention mechanisms with unified interface."""
 
@@ -17,19 +26,14 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
         """Forward pass for attention mechanism.
 
         Args:
             x: Input tensor of shape (batch_size, seq_len, dim)
             freqs_cos, freqs_sin: Rotary position embedding frequencies
-            mask: Optional attention mask
-            input_pos: Positions for KV cache updates
-            in_cache_state/out_cache_state: Cache states
+            ForwardOptions: grouped optional args
 
         Returns:
             Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
+        input_pos = kwargs.get("input_pos")
         bsz, seqlen, _ = x.shape
 
         # QKV
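
For illustration, here is a minimal standalone sketch of how the ForwardOptions / **kwargs pattern introduced by this commit is used. It is not code from the commit: DummyAttention and its placeholder forward body are hypothetical, while the ForwardOptions fields mirror the diff above. Each implementation reads only the options it needs via kwargs.get(), so adding a new optional argument later only means touching the TypedDict, not every attention's forward() signature.

from typing import Any, Optional, Tuple, TypedDict

import torch
import torch.nn as nn


class ForwardOptions(TypedDict, total=False):
    # Same fields as in the diff above; all keys are optional because total=False.
    mask: Optional[torch.Tensor]
    input_pos: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]


class DummyAttention(nn.Module):
    """Hypothetical stand-in for a concrete Attention; real code computes QKV, RoPE, etc."""

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: ForwardOptions,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Pull out only the grouped options this implementation cares about;
        # options it does not use are simply ignored.
        input_pos = kwargs.get("input_pos")
        # A real implementation would use input_pos (and mask) here; this stub
        # just echoes the input back unchanged.
        return x, kwargs.get("out_cache_state")


attn = DummyAttention()
x = torch.randn(1, 4, 8)
freqs = torch.zeros(4, 4)
# Callers pass whichever ForwardOptions keys apply to their setup.
out, cache_state = attn(x, freqs, freqs, input_pos=torch.tensor([0]))
print(out.shape, cache_state)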

0 commit comments
