From 158779f61e05012fb208421ac58445bde9f49b16 Mon Sep 17 00:00:00 2001
From: Tarun Karuturi
Date: Mon, 6 Jan 2025 13:53:06 -0800
Subject: [PATCH] Add utils to replace torchtune SDPA with ET Custom SDPA

[ghstack-poisoned]
---
 extension/llm/modules/attention.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 60183801b4..232f106b38 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -13,6 +13,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 
 logger = logging.getLogger(__name__)
 
@@ -310,7 +311,9 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        if input_pos is None:
+            input_pos = torch.tensor(0)
+        output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
 
 
@@ -364,6 +367,7 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
+
         output = self._attention_fn(
             q,
             k,
@@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(is_causal=child.is_causal),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops
+    _replace_sdpa_with_custom_op(module)
+    return module
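--
A minimal usage sketch for the new utility, assuming a torchtune model
builder (`llama3_2_1b`) and the in-tree import path for this module; both
are illustrative choices, not part of the patch. The attention swap has to
run first, since the `SDPA` module that `replace_sdpa_with_custom_op` looks
for lives inside the inference `MultiHeadAttention` defined in this file,
and the `custom_ops` import inside the new util registers the ET custom
SDPA op before any `SDPACustom` forward runs.

    from torchtune.models.llama3_2 import llama3_2_1b  # assumed example model
    from executorch.extension.llm.modules.attention import (
        replace_mha_with_inference_mha,
        replace_sdpa_with_custom_op,
    )

    model = llama3_2_1b()
    # Swap torchtune MultiHeadAttention for the ET inference MHA, then
    # swap its torchtune-style SDPA for the ET custom SDPA op.
    model = replace_mha_with_inference_mha(model)
    model = replace_sdpa_with_custom_op(model)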