diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 59bfbe6f95..eff6ee5aec 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -22,20 +22,24 @@
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
+        kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
+        dim: int = -1,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
+        if kv_cache is None:
+            pass
+        elif not isinstance(kv_cache, QuantizedKVCache):
             self.kv_cache = kv_cache.to(torch.float)
         else:
             assert (
                 kv_cache.cache_fp_type == torch.float32
             ), "Only float32 is supported for custom SDPA"
         self.dim = dim
+        self.is_causal = is_causal
 
     def forward(
         self,
@@ -44,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen,
-        mask,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
@@ -54,9 +58,20 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        k_cache = self.kv_cache.k_cache
-        v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
+        k_cache = self.kv_cache.k_cache if self.kv_cache is not None else None
+        v_cache = self.kv_cache.v_cache if self.kv_cache is not None else None
+
+        if self.kv_cache is None:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos,
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                self.is_causal,  # is_causal
+            )
+        elif isinstance(self.kv_cache, QuantizedKVCache):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
@@ -68,7 +83,7 @@ def forward(
                 input_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         else:
             output = torch.ops.llama.sdpa_with_kv_cache(
@@ -81,7 +96,7 @@ def forward(
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                self.is_causal,  # is_causal
             )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
@@ -99,7 +114,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops
 
     _replace_sdpa_with_custom_op(module)
    return module
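
Usage sketch (not part of the patch): a minimal eager-mode example of the new cache-less path that this change enables. It assumes the ExecuTorch custom-op library is built and importable so that `torch.ops.llama.custom_sdpa` is registered, and it assumes q/k/v use the `[bsz, seqlen, n_heads, head_dim]` layout consumed by the custom SDPA ops; the tensor sizes and the prefill-style `input_pos = [0]` are illustrative, not values taken from the patch.

```python
# Illustrative only; shapes, sizes, and input_pos are assumptions, not part of the patch.
import torch

# Importing custom_ops registers torch.ops.llama.* (same import used by
# replace_sdpa_with_custom_op in this file).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

bsz, seqlen, n_heads, head_dim = 1, 16, 8, 64

# No KV cache: dim defaults to -1, so the final .view(bsz, seqlen, -1)
# infers the hidden size; is_causal defaults to True.
sdpa = SDPACustom()

# Assumed [bsz, seqlen, n_heads, head_dim] layout for the custom op.
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0])  # prefill-style start position

# seqlen is still needed for the output reshape even though it now defaults to None.
out = sdpa(input_pos, q, k, v, bsz, seqlen)
print(out.shape)  # expected: torch.Size([1, 16, 512]), i.e. (bsz, seqlen, n_heads * head_dim)
```

A non-causal variant would be constructed as `SDPACustom(is_causal=False)`; the same flag now also flows into the two cached call sites instead of the previously hard-coded `True`.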