Changes to SDPA to support no kv cache export
[ghstack-poisoned]
tarun292 committed Jan 6, 2025
1 parent 68c0208 commit 3979fc8
Showing 1 changed file with 27 additions and 12 deletions.
examples/models/llama/source_transformation/sdpa.py (27 additions & 12 deletions)
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 
@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +44,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
     return module
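
Not part of the commit: a minimal usage sketch of the new no-KV-cache path added above. It assumes the ExecuTorch custom op library is built and importable (the same `custom_ops` import used in this file), that the module is importable as `executorch.examples.models.llama.source_transformation.sdpa`, and that q/k/v use a (bsz, seqlen, n_heads, head_dim) layout with an integer start position; the shapes, layout, and exact `input_pos` type used by the real export flow may differ.

# Illustrative sketch only; layout, shapes, and input_pos handling are assumptions.
import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  # loads torch.ops.llama.*
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

bsz, seqlen, n_heads, head_dim = 1, 16, 8, 64

# With this change, kv_cache can be omitted entirely and is_causal is configurable.
sdpa = SDPACustom(kv_cache=None, dim=n_heads * head_dim, is_causal=True)

q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)

# kv_cache is None, so forward routes to torch.ops.llama.custom_sdpa (no cache
# update). seqlen and mask now default to None, but seqlen is still needed for
# the final view to (bsz, seqlen, dim).
out = sdpa(0, q, k, v, bsz, seqlen)

In practice the swap is applied model-wide via `replace_sdpa_with_custom_op`, defined at the end of this file.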
