Changes to SDPA to support no kv cache export #7530

Open · wants to merge 2 commits into base: gh/tarun292/1/base · showing changes from 1 commit
39 changes: 27 additions & 12 deletions examples/models/llama/source_transformation/sdpa.py
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer

 import math
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional

 import torch

@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal

     def forward(
         self,
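For context, a minimal construction sketch under the new defaults. The import path follows the ExecuTorch repo layout and the dim value is a placeholder; none of this is part of the diff itself.

    from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

    # New in this PR: kv_cache defaults to None and dim to -1, so an SDPA module
    # for no-kv-cache export only needs the output dim (and optionally is_causal).
    sdpa_no_cache = SDPACustom(dim=2048, is_causal=True)

    # Existing behavior is unchanged: a KVCache is converted to float32 in __init__,
    # and a QuantizedKVCache must already hold float32 values.
    # sdpa_cached = SDPACustom(kv_cache=my_kv_cache, dim=2048)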
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next

Reviewer comment (Contributor), on the input_pos argument of the new custom_sdpa call above: Tbh I think you can just set this to 0; it should work for the no-kv-cache text decoder as well, since it represents the start position, so you don't need to set input_pos to torch.tensor(0) in your other PR.
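If that suggestion is adopted, only the start-position argument of the new branch changes; an illustrative fragment, not part of this PR:

-                input_pos,
+                0,  # start position; always 0 when exporting without a KV cache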
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)

@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops

     _replace_sdpa_with_custom_op(module)
     return module
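For reference, a sketch of how this transform is typically applied when targeting no-kv-cache export. The wrapper name prepare_for_no_kv_cache_export is hypothetical and the import path assumes the ExecuTorch repo layout; only replace_sdpa_with_custom_op comes from the file changed in this PR.

    import torch

    from executorch.examples.models.llama.source_transformation.sdpa import (
        replace_sdpa_with_custom_op,
    )


    def prepare_for_no_kv_cache_export(model: torch.nn.Module) -> torch.nn.Module:
        # Swap each eager SDPA submodule for SDPACustom. With this PR the
        # swapped-in modules no longer require a KVCache instance, which is
        # what enables the no-kv-cache export path this change targets.
        return replace_sdpa_with_custom_op(model)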