optimize llama3.2 vision attention again (intel-analytics#12204)
MeouSker77 authored Oct 15, 2024
1 parent 9b81236 commit f6611f9
Showing 1 changed file with 18 additions and 7 deletions.
python/llm/src/ipex_llm/transformers/models/mllama.py: 25 changes (18 additions, 7 deletions)
@@ -36,6 +36,7 @@
 import torch
 
 from typing import Optional
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 
 
 def mllama_vision_attention_forward(
@@ -55,17 +56,27 @@ def mllama_vision_attention_forward(
     key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
     value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
-
     if attention_mask is not None:  # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+    else:
+        causal_mask = None
 
+    if use_sdp_non_causal(self.head_dim, query.device, query.dtype):
+        import xe_addons
+        attn_output = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                               value.contiguous(), causal_mask)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + causal_mask
+
-    # upcast attention to fp32
-    from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, self.training)
+        # upcast attention to fp32
+        from ipex_llm.transformers.models.common import attention_softmax
+        attn_weights = attention_softmax(attn_weights, False)
 
-    attn_output = torch.matmul(attn_weights, value)
+        attn_output = torch.matmul(attn_weights, value)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
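For context, both branches in the new code compute the same non-causal scaled-dot-product attention; the xe_addons.sdp_non_causal kernel fuses it on GPUs where use_sdp_non_causal() returns True, while the else-branch keeps the eager matmul-plus-softmax fallback. A minimal sketch of the equivalent math in stock PyTorch follows; it is an illustration only, the function names are hypothetical, and the (batch, num_heads, seq_len, head_dim) tensor layout is assumed from the surrounding code.

import math
import torch
import torch.nn.functional as F

def eager_non_causal_attention(query, key, value, attention_mask=None):
    # Mirrors the fallback branch in the diff:
    # softmax(Q @ K^T / sqrt(head_dim) + mask) @ V
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # upcast to fp32 for a numerically stable softmax, then cast back
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)

def fused_non_causal_attention(query, key, value, attention_mask=None):
    # Same computation via PyTorch's built-in fused SDPA; this is the role
    # the xe_addons.sdp_non_causal call plays on supported Intel GPUs
    # (assumed equivalence, shown for illustration).
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=False)

Both helpers accept the same padding/attention mask that the diff slices into causal_mask, so swapping between them should not change the numerical result beyond normal floating-point variation.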
