Commit 3979fc8

Changes to SDPA to support no kv cache export
[ghstack-poisoned]
1 parent 68c0208 commit 3979fc8

File tree

1 file changed: +27 -12 lines
  • examples/models/llama/source_transformation/sdpa.py

examples/models/llama/source_transformation/sdpa.py

Lines changed: 27 additions & 12 deletions
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 
@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
     return module
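
For context, a minimal sketch of how the new no-KV-cache path might be exercised. The import path is inferred from the file's location in the repo and the dim value is a placeholder; only SDPACustom, its new defaults (kv_cache=None, dim=-1, is_causal=True), and replace_sdpa_with_custom_op come from the diff above.

    # Sketch only: construct SDPACustom without a KV cache, as enabled by this commit.
    # Import path is an assumption based on the file's location.
    from executorch.examples.models.llama.source_transformation.sdpa import (
        SDPACustom,
        replace_sdpa_with_custom_op,
    )

    # With kv_cache=None, forward() skips the cache update and calls
    # torch.ops.llama.custom_sdpa directly, honoring the new is_causal flag.
    sdpa = SDPACustom(kv_cache=None, dim=2048, is_causal=True)  # dim: placeholder

    # Alternatively, rewrite the SDPA modules of an existing eager model in place
    # before export; the helper's inner import of custom_ops registers the
    # torch.ops.llama.* ops used in forward().
    # model = replace_sdpa_with_custom_op(model)  # `model` is hypothetical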
