@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 
@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
     return module
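
For orientation, a minimal sketch of exercising the new cache-free path introduced by this diff. The import path for SDPACustom and the [batch, seq_len, n_heads, head_dim] layout of q/k/v are assumptions, not taken from the diff; the custom-op registration import is the same one used above.

import torch

# Registers torch.ops.llama.custom_sdpa / sdpa_with_kv_cache (same import as in the diff).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# Assumed module path for SDPACustom; adjust to where this file lives in your checkout.
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

bsz, seqlen, n_heads, head_dim = 1, 8, 4, 16
q = torch.randn(bsz, seqlen, n_heads, head_dim)  # assumed [B, S, H, D] layout
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0])

# With kv_cache=None (the new default), forward dispatches to
# torch.ops.llama.custom_sdpa and performs no cache update.
sdpa = SDPACustom()  # kv_cache=None, dim=-1, is_causal=True
out = sdpa(input_pos, q, k, v, bsz, seqlen)
print(out.shape)  # expected: [bsz, seqlen, n_heads * head_dim]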