Skip to content

Commit

Permalink
Add utils to replace torchtune SDPA with ET Custom SDPA
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
tarun292 committed Jan 6, 2025
1 parent 3979fc8 commit 158779f
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion extension/llm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -310,7 +311,9 @@ def false_fn(y):
self.kv_cache.v_cache.copy_(v)
self.kv_cache.cache_pos.copy_(cache_pos)

output = self._sdpa(q, k, v, b, s_x, mask=mask)
if input_pos is None:
input_pos = torch.tensor(0)
output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
return self.output_proj(output)


Expand Down Expand Up @@ -364,6 +367,7 @@ def forward(
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)


output = self._attention_fn(
q,
k,
Expand Down Expand Up @@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
"""
_replace_mha_with_inference_mha(module)
return module


def _replace_sdpa_with_custom_op(module: torch.nn.Module) -> None:
    """Recursively swap every ``SDPA`` child of *module* for ``SDPACustom``.

    Walks the module tree depth-first via ``named_children`` and replaces
    each ``SDPA`` instance in place with an ExecuTorch-optimized
    ``SDPACustom`` carrying the same ``is_causal`` setting. Non-SDPA
    children are descended into; the mutation happens via ``setattr``.
    """
    for child_name, child_module in module.named_children():
        if not isinstance(child_module, SDPA):
            # Not a target — keep walking this subtree.
            _replace_sdpa_with_custom_op(child_module)
            continue
        replacement = SDPACustom(is_causal=child_module.is_causal)
        setattr(module, child_name, replacement)


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
    """Replace all torchtune SDPA modules in *module* with the ET custom SDPA.

    The ``custom_ops`` import is performed for its side effect: it registers
    the ExecuTorch custom SDPA operator with PyTorch before any module is
    swapped. The input module is mutated in place and also returned for
    call-chaining convenience.
    """
    # Side-effect import — registers the custom SDPA op; the name itself
    # is intentionally unused.
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

    _replace_sdpa_with_custom_op(module)
    return module

0 comments on commit 158779f

Please sign in to comment.