From f6611f9d3a1c4f9bde75d9730571c5f4e1f64229 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 15 Oct 2024 16:08:20 +0800
Subject: [PATCH] optimize llama3.2 vision attention again (#12204)

---
 .../ipex_llm/transformers/models/mllama.py | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mllama.py b/python/llm/src/ipex_llm/transformers/models/mllama.py
index 69fae966b44..9752ebe9e36 100644
--- a/python/llm/src/ipex_llm/transformers/models/mllama.py
+++ b/python/llm/src/ipex_llm/transformers/models/mllama.py
@@ -36,6 +36,7 @@
 
 import torch
 from typing import Optional
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 
 
 def mllama_vision_attention_forward(
@@ -55,17 +56,27 @@ def mllama_vision_attention_forward(
     key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
     value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
-
     if attention_mask is not None:  # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+    else:
+        causal_mask = None
+
+    if use_sdp_non_causal(self.head_dim, query.device, query.dtype):
+        import xe_addons
+        attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                               value.contiguous(), causal_mask)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + causal_mask
 
-    # upcast attention to fp32
-    from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, self.training)
+        # upcast attention to fp32
+        from ipex_llm.transformers.models.common import attention_softmax
+        attn_weights = attention_softmax(attn_weights, False)
 
-    attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.matmul(attn_weights, value)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
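
Note (not part of the patch): the diff makes the vision attention forward dispatch between a fused non-causal SDP kernel (xe_addons.sdp_non_causal, taken when use_sdp_non_causal(...) returns True) and the original eager matmul + softmax path. The sketch below mirrors that dispatch pattern in plain PyTorch, assuming F.scaled_dot_product_attention as a stand-in for the XPU kernel and torch.softmax in place of ipex-llm's attention_softmax helper; the function name vision_attention_reference, the use_fused_sdp flag, and the tensor shapes are illustrative, not taken from the repository.

    import math

    import torch
    import torch.nn.functional as F


    def vision_attention_reference(query: torch.Tensor,
                                   key: torch.Tensor,
                                   value: torch.Tensor,
                                   attention_mask: torch.Tensor = None,
                                   use_fused_sdp: bool = True) -> torch.Tensor:
        """Non-causal attention over (batch, heads, seq, head_dim) tensors.

        `use_fused_sdp` stands in for the patch's `use_sdp_non_causal(...)`
        check; the fused branch uses PyTorch's built-in SDPA rather than
        the Intel XPU `xe_addons.sdp_non_causal` kernel.
        """
        # Slice the mask to the key length, mirroring the patch.
        causal_mask = None
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key.shape[-2]]

        if use_fused_sdp:
            # Fused path: one kernel computes softmax(QK^T / sqrt(d) + mask) @ V.
            return F.scaled_dot_product_attention(query, key, value,
                                                  attn_mask=causal_mask,
                                                  is_causal=False)

        # Eager fallback: explicit matmul + softmax, upcast to fp32 for stability.
        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.shape[-1])
        if causal_mask is not None:
            attn_weights = attn_weights + causal_mask
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights.to(query.dtype)
        return torch.matmul(attn_weights, value)


    if __name__ == "__main__":
        # Example shapes only: (batch, heads, seq, head_dim).
        q = torch.randn(1, 16, 1024, 80)
        k, v = torch.randn_like(q), torch.randn_like(q)
        out_fused = vision_attention_reference(q, k, v, use_fused_sdp=True)
        out_eager = vision_attention_reference(q, k, v, use_fused_sdp=False)
        print(torch.allclose(out_fused, out_eager, atol=1e-4))

In the fused branch the patch sets attn_weights to None, consistent with only returning attention weights when they are actually materialized by the eager path.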