diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 409a4d1210bc3..fc363585b0e7e 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -410,7 +410,7 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
- * `Phi3ForCausalLM`
* Phi-4, Phi-3
- * `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
+ * `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
* ✅︎
* ✅︎
- * `Phi3SmallForCausalLM`
@@ -856,6 +856,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
+- * `Phi4MMForCausalLM`
+ * Phi-4-multimodal
+ * T + I+ / T + A+ / I+ + A+
+ * `microsoft/Phi-4-multimodal-instruct`, etc.
+ * ✅︎
+ *
+ *
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I+
diff --git a/requirements-common.txt b/requirements-common.txt
index da6bc2af68cf2..1ca64c9948974 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -37,3 +37,4 @@ depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
python-json-logger # Used by logging as per examples/other/logging_configuration.md
+scipy # Required for phi-4-multimodal-instruct
\ No newline at end of file
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 97db33b46fade..3c3247eaf3e99 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -272,6 +272,8 @@ def check_available_online(
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True),
+ "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
+ trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
diff --git a/vllm/config.py b/vllm/config.py
index f87d2d6e82cf8..3f1bff4981294 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2284,9 +2284,9 @@ def compute_hash(self) -> str:
return hash_str
def __post_init__(self):
- # Setting the maximum rank to 256 should be able to satisfy the vast
+ # Setting the maximum rank to 512 should be able to satisfy the vast
# majority of applications.
- possible_max_ranks = (8, 16, 32, 64, 128, 256)
+ possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index b05842dd27d3b..8f906cf1d80b6 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -395,6 +395,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
+ if model_type == "phi4mm":
+ return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(./)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
@@ -424,6 +426,8 @@ def _placeholder_str(self, modality: ModalityStr,
elif modality == "audio":
if model_type == "ultravox":
return "<|audio|>"
+ if model_type == "phi4mm":
+ return "<|endoftext11|>" # 200011 (see vocab.json in hf model)
if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
new file mode 100644
index 0000000000000..27ae9bcca2e43
--- /dev/null
+++ b/vllm/model_executor/models/phi4mm.py
@@ -0,0 +1,1803 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+import re
+from functools import lru_cache
+from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
+ TypedDict, Union)
+
+import numpy as np
+import scipy.signal
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import PretrainedConfig
+from transformers.utils import logging
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+ InputContext)
+from vllm.inputs.data import TokenInputs, token_inputs
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
+from vllm.model_executor.models.llama import LlamaModel
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
+from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+
+from .interfaces import SupportsLoRA, SupportsMultiModal
+from .phi4mm_audio import AudioEmbedding
+from .utils import maybe_prefix
+from .vision_siglip_navit import get_siglip_vision_model
+
+# <|endoftext10|> (see vocab.json in hf model)
+_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
+# <|endoftext11|>
+_AUDIO_PLACEHOLDER_TOKEN_ID = 200011
+
+_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
+DUMMY_SAMPLING_FREQUENCY = 16_000 # Hz
+
+DYNAMIC_HD = 16
+AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
+IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"
+
+SIGLIP_NAME = "siglip-so400m-patch14-448"
+VISION_ENCODER_TO_PROCESSING_CONFIG = {
+ 'siglip-so400m-patch14-448': {
+ 'dynamic_hd': 16,
+ 'vit_image_size': 448,
+ 'vit_patch_size': 14,
+ 'token_compression_factor': 2,
+ },
+}
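+# For the default SigLIP config above: 448 / 14 = 32 patches per side, and a
+# token compression factor of 2 gives (32 // 2) ** 2 = 256 tokens for the
+# global image crop (this matches the hard-coded 256 in `preprocess` below).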
+logger = logging.get_logger(__name__)
+# This is a workaround to prevent text (user input) + audio + image
+# from being used in the same prompt.
+# It includes the token ids for "\n" and the tokens in added_tokens_decoder
+# from the tokenizer_config.json file.
+NON_USER_INPUT_TOKENS = {
+ 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022,
+ 200023, 200024, 200025, 200026, 200027, 200028
+}
+
+
+def get_max_dummy_image(ctx: InputContext):
+ hf_config = ctx.get_hf_config()
+ vision_encoder_name = hf_config.img_processor
+ if vision_encoder_name is None:
+ vision_encoder_name = SIGLIP_NAME
+ prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
+ dynamic_hd_size = prepro_config['dynamic_hd']
+ vit_image_size = prepro_config['vit_image_size']
+
+ max_side = vit_image_size * dynamic_hd_size
+ dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side)
+ return dummy_image
+
+
+# image token length
+def get_max_phi4mm_image_tokens(ctx: InputContext):
+ dummy_image = get_max_dummy_image(ctx)
+
+ hf_config = ctx.get_hf_config()
+ vision_encoder_name = hf_config.img_processor
+ if vision_encoder_name is None:
+ vision_encoder_name = SIGLIP_NAME
+ prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
+ dynamic_hd_size = prepro_config['dynamic_hd']
+ vit_image_size = prepro_config['vit_image_size']
+ vit_patch_size = prepro_config['vit_patch_size']
+ token_compression_factor = prepro_config['token_compression_factor']
+
+ image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size,
+ vit_image_size,
+ vit_patch_size,
+ token_compression_factor)
+ return image_num_tokens
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
+ image_size):
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+
+def _find_target_aspect_ratio(image, image_size, max_num, min_num):
+ orig_width, orig_height = image.size
+
+ w_crop_num = math.ceil(orig_width / float(image_size))
+ h_crop_num = math.ceil(orig_height / float(image_size))
+ if w_crop_num * h_crop_num > max_num:
+ aspect_ratio = orig_width / orig_height
+
+        # enumerate the candidate crop grids within [min_num, max_num] crops
+ target_ratios = set((i, j) for i in range(1, max_num + 1)
+ for j in range(1, max_num + 1)
+ if i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ logger.debug("target_aspect_ratio: %s", target_aspect_ratio)
+ else:
+ target_width = image_size * w_crop_num
+ target_height = image_size * h_crop_num
+ target_aspect_ratio = (w_crop_num, h_crop_num)
+ return target_aspect_ratio, target_height, target_width
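+
+# Worked example (illustrative): a 1000x700 image with image_size=448 and
+# max_num=16 gives w_crop_num = ceil(1000/448) = 3 and
+# h_crop_num = ceil(700/448) = 2, so 3 * 2 = 6 <= 16 crops are used and the
+# target canvas is 1344x896 with target_aspect_ratio = (3, 2).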
+
+
+def _get_padding_size(image, target_height, target_width):
+ orig_width, orig_height = image.size
+ ratio_width = target_width / orig_width
+ ratio_height = target_height / orig_height
+
+ if ratio_width < ratio_height:
+ padding_width = 0
+ padding_height = target_height - int(orig_height * ratio_width)
+ else:
+ padding_width = target_width - int(orig_width * ratio_height)
+ padding_height = 0
+ return padding_height, padding_width
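+
+# Continuing the illustrative 1000x700 example: with a 1344x896 target,
+# ratio_width = 1344/1000 = 1.344 and ratio_height = 896/700 = 1.28, so
+# padding_height = 0 and padding_width = 1344 - int(1000 * 1.28) = 64.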
+
+
+def dynamic_preprocess(image,
+ min_num=1,
+ max_num=12,
+ image_size=384,
+ mask_size=27):
+ target_aspect_ratio, target_height, target_width =\
+ _find_target_aspect_ratio(
+ image, image_size, max_num, min_num)
+ padding_height, padding_width = _get_padding_size(image, target_height,
+ target_width)
+
+ # Calculate the ratio
+ orig_width, orig_height = image.size
+ ratio_width = target_width / orig_width
+ ratio_height = target_height / orig_height
+ if ratio_width < ratio_height:
+ new_size = (target_width, int(orig_height * ratio_width))
+ else:
+ new_size = (int(orig_width * ratio_height), target_height)
+
+ attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]),
+ int(mask_size * target_aspect_ratio[0])))
+ if padding_width >= 14:
+ attention_mask[:, -math.floor(padding_width / 14):] = 0
+ if padding_height >= 14:
+ attention_mask[-math.floor(padding_height / 14):, :] = 0
+ assert attention_mask.sum(
+ ) > 0, f'attention mask is empty {attention_mask}'
+
+ if min(new_size[1], target_height) < 10 or min(new_size[0],
+ target_width) < 10:
+ raise ValueError(f'the aspect ratio is very extreme {new_size}')
+
+ image = T.functional.resize(
+ image,
+ [new_size[1], new_size[0]],
+ )
+
+ resized_img = T.functional.pad(image,
+ [0, 0, padding_width, padding_height],
+ fill=[255, 255, 255])
+
+ return resized_img, attention_mask
+
+
+def pad_to_max_num_crops(images, max_crops=5):
+ """
+ images: B x 3 x H x W, B<=max_crops
+ """
+ B, _, H, W = images.shape
+ if max_crops > B:
+ pad = torch.zeros(max_crops - B,
+ 3,
+ H,
+ W,
+ dtype=images.dtype,
+ device=images.device)
+ images = torch.cat([images, pad], dim=0)
+ return images
+
+
+def pad_mask_to_max_num_crops(masks, max_crops=5):
+ B, H, W = masks.shape
+ if max_crops > B:
+ pad = torch.ones(max_crops - B,
+ H,
+ W,
+ dtype=masks.dtype,
+ device=masks.device)
+ masks = torch.cat([masks, pad], dim=0)
+ return masks
+
+
+def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
+
+ # Basic settings.
+ img_processor = T.Compose([
+ T.ToTensor(),
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ])
+ # Dynamic HD
+ base_resolution = vit_resolution
+ images = [image.convert('RGB') for image in images]
+ # cover 384 and 448 resolution
+ mask_resolution = base_resolution // vit_patch_size
+ elems, image_attention_masks = [], []
+ for im in images:
+ elem, attention_mask = dynamic_preprocess(im,
+ max_num=dynamic_hd_size,
+ image_size=base_resolution,
+ mask_size=mask_resolution)
+ elems.append(elem)
+ image_attention_masks.append(attention_mask)
+ hd_images = [img_processor(im) for im in elems]
+ global_image = [
+ torch.nn.functional.interpolate(
+ im.unsqueeze(0).float(),
+ size=(base_resolution, base_resolution),
+ mode='bicubic',
+ ).to(im.dtype) for im in hd_images
+ ]
+ shapes = [[im.size(1), im.size(2)] for im in hd_images]
+ mask_shapes = [[mask.size(0), mask.size(1)]
+ for mask in image_attention_masks]
+ global_attention_mask = [
+ torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images
+ ]
+ hd_images_reshape = [
+ im.reshape(1, 3, h // base_resolution, base_resolution,
+ w // base_resolution, base_resolution).permute(
+ 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution,
+ base_resolution).contiguous()
+ for im, (h, w) in zip(hd_images, shapes)
+ ]
+ attention_masks_reshape = [
+ mask.reshape(1, h // mask_resolution, mask_resolution,
+ w // mask_resolution, mask_resolution).permute(
+ 0, 1, 3, 2, 4).reshape(-1, mask_resolution,
+ mask_resolution).contiguous()
+ for mask, (h, w) in zip(image_attention_masks, mask_shapes)
+ ]
+    # NOTE token compression is hard-coded here, and odd numbers seem to fail
+ downsample_attention_masks = [
+ mask[:, 0::2,
+ 0::2].reshape(1, h // mask_resolution, w // mask_resolution,
+ mask_resolution // 2 + mask_resolution % 2,
+ mask_resolution // 2 + mask_resolution % 2).permute(
+ 0, 1, 3, 2, 4)
+ for mask, (h, w) in zip(attention_masks_reshape, mask_shapes)
+ ]
+ downsample_attention_masks = [
+ mask.reshape(mask.size(1) * mask.size(2),
+ mask.size(3) * mask.size(4))
+ for mask in downsample_attention_masks
+ ]
+ # NOTE hard coded number of tokens
+ num_img_tokens = [
+ 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16
+ for mask in downsample_attention_masks
+ ]
+
+ hd_images_reshape = [
+ torch.cat([_global_image] + [_im], dim=0)
+ for _global_image, _im in zip(global_image, hd_images_reshape)
+ ]
+ hd_masks_reshape = [
+ torch.cat([_global_mask] + [_mask],
+ dim=0) for _global_mask, _mask in zip(
+ global_attention_mask, attention_masks_reshape)
+ ]
+ max_crops = max([img.size(0) for img in hd_images_reshape])
+ image_transformed = [
+ pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape
+ ]
+ image_transformed = torch.stack(image_transformed, dim=0)
+ mask_transformed = [
+ pad_mask_to_max_num_crops(mask, max_crops) \
+ for mask in hd_masks_reshape
+ ]
+ mask_transformed = torch.stack(mask_transformed, dim=0)
+
+ returned_input_image_embeds = image_transformed
+ returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
+ returned_image_attention_mask = mask_transformed
+ returned_num_img_tokens = num_img_tokens
+
+ data = {
+ "pixel_values": returned_input_image_embeds,
+ "image_sizes": returned_image_sizes,
+ "image_attention_mask": returned_image_attention_mask,
+ "num_img_tokens": returned_num_img_tokens,
+ }
+ return data
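+
+# Shape sketch for the illustrative 1000x700 example: the image is resized
+# and padded to a 3x2 grid of 448x448 crops, so `pixel_values` has
+# 1 global crop + 6 HD crops = 7 crops, i.e. shape (1, 7, 3, 448, 448), and
+# `image_sizes` is [[896, 1344]] (height, width of the padded HD image).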
+
+
+class Phi4MMImageEncoder(nn.Module):
+ """Image embedding."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ model_dir: str = "") -> None:
+ super().__init__()
+
+ # n_embed or hidden_size
+ hidden_size = config.n_embd if hasattr(
+ config, 'n_embd') else config.hidden_size
+ if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
+ embd_drop = config.embd_pdrop if hasattr(
+ config, 'embd_pdrop') else config.embed_pdrop
+ self.drop = nn.Dropout(embd_drop)
+ else:
+ self.drop = None
+
+ # layer_idx to output the img features
+ if isinstance(config.img_processor, dict):
+ self.layer_idx = config.img_processor.get('layer_idx', -2)
+ self.type_feature = config.img_processor.get(
+ 'type_feature', 'patch')
+ else:
+ self.layer_idx = -2
+ self.type_feature = 'patch'
+
+ self.img_processor = get_siglip_vision_model(
+ _flash_attn_2_enabled=True)
+
+ pe_weight = self.img_processor.embeddings.position_embedding.weight
+ L, D = pe_weight.size()
+ H = int(math.sqrt(L))
+ assert H**2 == L, f'position embedding size {L} is not square'
+ if H % 2 != 0:
+ self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
+ H += 1
+ image_dim_out = D
+ # ((448/14)//2)**2
+ self.num_img_tokens = (H // 2)**2
+ self.base_feat_height_target = H
+
+ self.image_dim_out = image_dim_out
+ self.img_sizes = None
+ self.image_attention_mask = None
+
+ # global_gn and sub_gn for hd transform, serves as line separator
+ self.use_hd_transform = True
+ self.with_learnable_separator = True
+ self.hd_transform_order = "sub_glb"
+ self.freeze_img_processor = False
+ self.crop_size = 448
+
+ # image token compression
+ self.image_token_compression_cls = 'avg_pool_2d'
+ self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.base_feat_height_reduction = 1
+ self.base_feat_height_target = self.base_feat_height_target // 2
+
+ # with_hd_transform and with_learnable_separator should have same value
+ assert self.use_hd_transform == self.with_learnable_separator, \
+ 'use_hd_transform and with_learnable_separator should have same value'
+ assert self.use_hd_transform, \
+ 'learnable separator is only for hd transform'
+ # 1024 * 4, merge spatial to channel dimension
+ self.glb_GN = nn.Parameter(
+ torch.zeros([
+ 1, 1, self.image_dim_out * self.base_feat_height_reduction**2
+ ]))
+ self.sub_GN = nn.Parameter(
+ torch.zeros([
+ 1, 1, 1,
+ self.image_dim_out * self.base_feat_height_reduction**2
+ ]))
+
+ dim_projection = hidden_size
+ depth = 2
+ layers = [
+ nn.Linear(image_dim_out * self.base_feat_height_reduction**2,
+ dim_projection)
+ ]
+ for _ in range(1, depth):
+ layers.extend(
+ [nn.GELU(),
+ nn.Linear(dim_projection, dim_projection)])
+ self.img_projection = nn.Sequential(*layers)
+
+ self.vocab_size = config.vocab_size
+ self.img_features = None
+
+ self.use_out_place_operations = False
+
+ def get_img_features(self,
+ img_embeds: torch.FloatTensor,
+ attention_mask=None) -> torch.FloatTensor:
+ LAYER_IDX = self.layer_idx
+ TYPE_FEATURE = self.type_feature
+
+ img_processor_output = self.img_processor(
+ img_embeds,
+ output_hidden_states=True,
+ patch_attention_mask=attention_mask)
+ img_feature = img_processor_output.hidden_states[LAYER_IDX]
+
+ if TYPE_FEATURE == "patch":
+ patch_feature = img_feature
+
+ use_token_compression = self.image_token_compression is not None
+ use_padding = getattr(self, 'img_processor_padding',
+ None) is not None
+ if use_token_compression or use_padding:
+ # reshape to 2D tensor
+ width = int(math.sqrt(patch_feature.size(1)))
+ patch_feature = patch_feature.view(-1, width, width,
+ patch_feature.size(-1))
+ # convert to NCHW
+ patch_feature = patch_feature.permute(0, 3, 1, 2)
+
+ if use_padding:
+ patch_feature = self.img_processor_padding(patch_feature)
+ if use_token_compression:
+ patch_feature = self.image_token_compression(patch_feature)
+
+ # convert to NHWC
+ patch_feature = patch_feature.permute(0, 2, 3, 1)
+ patch_feature = patch_feature.view(
+ -1,
+ patch_feature.size(1) * patch_feature.size(2),
+ patch_feature.size(-1))
+
+ return patch_feature
+
+ raise NotImplementedError
+
+ def forward(self, pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ image_attention_mask: torch.Tensor) -> torch.FloatTensor:
+ """
+ process image and return vision embeddings.
+
+ pixel_values: (num_images, num_crops, c, h, w)
+ image_sizes: [[h1, w1], [h2, w2]]
+ image_attention_mask: num_images x num_crops x 32 x 32
+ output: (num_images, num_img_tokens, hidden_size)
+ """
+
+        # e.g.:
+ # pixel_values: torch.Size([1, 7, 3, 448, 448])
+ # image_sizes: tensor([[ 896, 1344]], device='cuda:0')
+ # output: torch.Size([1, 1841, 3072])
+
+ if isinstance(self.img_projection, nn.Sequential):
+ target_device = self.img_projection[0].bias.device
+ target_dtype = self.img_projection[0].bias.dtype
+ else: # It's a single nn.Linear layer
+ target_device = self.img_projection.bias.device
+ target_dtype = self.img_projection.bias.dtype
+
+ img_sizes = image_sizes
+ num_images, num_crops, c, h, w = pixel_values.shape
+ bs = num_images
+ pixel_values = pixel_values.flatten(0, 1)
+
+ img_features = self.get_img_features(
+ pixel_values,
+ image_attention_mask.type(torch.BoolTensor).flatten(
+ 0, 1).to(target_device))
+
+ base_feat_height_target = self.base_feat_height_target
+ base_resolution = self.crop_size
+ base_feat_height_reduction = self.base_feat_height_reduction
+
+ base_feat_height = base_feat_width = int(np.sqrt(
+ img_features.shape[1]))
+ assert base_feat_height == base_feat_height_target \
+ and base_feat_width == base_feat_height_target, \
+            f'base_feat_height: {base_feat_height}, ' \
+            f'base_feat_width: {base_feat_width}, ' \
+            f'expect {base_feat_height_target} features for hd transform'
+
+ # bs x max_num_crops x (24x24) x C
+ img_features = img_features.view(bs, -1,
+ base_feat_height * base_feat_width,
+ self.image_dim_out)
+ C = self.image_dim_out
+ H = base_feat_height
+
+ output_imgs = []
+ output_len = []
+ # training is tensor, inference is list
+ if isinstance(img_sizes, torch.Tensor):
+ img_sizes = img_sizes.view(-1, 2)
+ for _bs in range(bs):
+ h, w = img_sizes[_bs]
+ h = h // base_resolution
+ w = w // base_resolution
+ B_ = h * w
+
+ # 1 x (24x24) x 1024
+ global_img_feature = img_features[_bs, :1]
+
+ # 1 x 12 x 12 x 4096
+ glb_img = global_img_feature.reshape(1, H, H, C).reshape(
+ 1, H // base_feat_height_reduction, base_feat_height_reduction,
+ H // base_feat_height_reduction, base_feat_height_reduction,
+ C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(
+ 1, H // base_feat_height_reduction,
+ H // base_feat_height_reduction,
+ base_feat_height_reduction * base_feat_height_reduction *
+ C).contiguous()
+ temp_glb_GN = self.sub_GN.repeat(1,
+ H // base_feat_height_reduction,
+ 1, 1)
+
+ # 1 x 156 x 4096
+ glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
+ 1, -1,
+ base_feat_height_reduction * base_feat_height_reduction * C)
+
+ # (max_num_crops-1) x (12x12) x C
+ sub_img = img_features[_bs, 1:]
+ # 16x574x1024
+ # get rid of padding sub_img
+ sub_img = sub_img[:B_]
+
+ # (num_crops, 12, 2, 12, 2, 1024) ->
+ # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
+ sub_img = sub_img.reshape(B_, H, H, C).reshape(
+ B_, H // base_feat_height_reduction,
+ base_feat_height_reduction, H // base_feat_height_reduction,
+ base_feat_height_reduction,
+ C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(
+ B_, -1, base_feat_height_reduction *
+ base_feat_height_reduction * C).contiguous()
+ sub_img = sub_img.reshape(
+ 1, h, w, base_feat_height // base_feat_height_reduction,
+ base_feat_width // base_feat_height_reduction,
+ -1).permute(0, 1, 3, 2, 4, 5).reshape(
+ 1, h * base_feat_height // base_feat_height_reduction,
+ w * base_feat_width // base_feat_height_reduction,
+ base_feat_height_reduction * base_feat_height_reduction *
+ C)
+
+ if image_attention_mask is not None and len(
+ image_attention_mask) > 0:
+ reshaped_image_attention_mask = image_attention_mask[
+ _bs, 1:B_ + 1, 0::2, 0::2].reshape(
+ 1, h, w,
+ base_feat_height // base_feat_height_reduction,
+ base_feat_width // base_feat_height_reduction).permute(
+ 0, 1, 3, 2, 4).reshape(
+ 1, h * base_feat_height //
+ base_feat_height_reduction, w *
+ base_feat_width // base_feat_height_reduction)
+ useful_height = int(
+ reshaped_image_attention_mask[0, :, 0].sum().item())
+ useful_width = int(
+ reshaped_image_attention_mask[0, 0, :].sum().item())
+ sub_img = sub_img[:, :useful_height, :useful_width]
+ temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
+ temp_len = int(
+ image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item(
+ )) + (useful_height +
+ 1) + base_feat_height // base_feat_height_reduction
+ else:
+ temp_sub_GN = self.sub_GN.repeat(
+ 1, h * base_feat_height // base_feat_height_reduction, 1,
+ 1)
+ temp_len = int((h * w + 1) * self.num_img_tokens + 1 +
+ (h + 1) * base_feat_height //
+ base_feat_height_reduction)
+
+ sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
+ 1, -1,
+ base_feat_height_reduction * base_feat_height_reduction * C)
+ # (1, num_img_tokens, 1024*4)
+
+ # glb + sub
+ if self.hd_transform_order == 'glb_sub':
+ output_imgs.append(
+ torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
+ elif self.hd_transform_order == 'sub_glb':
+ output_imgs.append(
+ torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
+ else:
+ raise NotImplementedError(
+                    f'hd_transform_order = {self.hd_transform_order}, '
+                    'not implemented')
+
+ #temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
+            assert temp_len == output_imgs[-1].shape[1], \
+                f'temp_len: {temp_len}, output_imgs[-1].shape[1]: ' \
+                f'{output_imgs[-1].shape[1]}'
+
+ output_len.append(temp_len)
+
+ img_set_tensor = []
+ for _output_img in output_imgs:
+ img_feature_proj = self.img_projection(
+ _output_img.to(target_device).to(target_dtype))
+ img_set_tensor.append(img_feature_proj)
+
+ return img_set_tensor
+
+
+class Phi4MMAudioFeatureInputs(TypedDict):
+ type: Literal["audio_features"]
+ data: Tuple[NestedTensors]
+ """Shape: `((batch_size, num_audios, 80, M), )"""
+
+
+class Phi4MMAudioEmbeddingInputs(TypedDict):
+ type: Literal["audio_embeds"]
+ data: NestedTensors
+ """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
+
+
+Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
+
+
+def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
+
+ Args:
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
+ n_fft (int): FFT size. int > 0 [scalar]
+        n_mels (int): Mel filter size. int > 0 [scalar]
+        fmin (float): lowest frequency (in Hz). If None use 0.0.
+            float >= 0 [scalar]
+        fmax (float): highest frequency (in Hz). If None use sample_rate / 2.
+            float >= 0 [scalar]
+
+ Returns
+ out (numpy.ndarray): Mel transform matrix
+ [shape=(n_mels, 1 + n_fft/2)]
+ """
+
+ bank_width = int(n_fft // 2 + 1)
+ if fmax is None:
+ fmax = sample_rate / 2
+ if fmin is None:
+ fmin = 0
+ assert fmin >= 0, "fmin cannot be negative"
+ assert (fmin < fmax <=
+ sample_rate / 2), "fmax must be between (fmin, samplerate / 2]"
+
+ def mel(f):
+ return 1127.0 * np.log(1.0 + f / 700.0)
+
+ def bin2mel(fft_bin):
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
+
+ def f2bin(f):
+ return int((f * n_fft / sample_rate) + 0.5)
+
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
+ klo = f2bin(fmin) + 1
+ khi = f2bin(fmax)
+
+ khi = max(khi, klo)
+
+ # Spec 2: SpeechLib uses triangles in Mel space
+ mlo = mel(fmin)
+ mhi = mel(fmax)
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
+ ms = (mhi - mlo) / (n_mels + 1)
+
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
+ for m in range(0, n_mels):
+ left = m_centers[m]
+ center = m_centers[m + 1]
+ right = m_centers[m + 2]
+ for fft_bin in range(klo, khi):
+ mbin = bin2mel(fft_bin)
+ if left < mbin < right:
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
+
+ return matrix
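+
+# Shape check (for the 16 kHz settings used below): speechlib_mel(16000, 512,
+# 80, fmin=None, fmax=7690) returns an (80, 512 // 2 + 1) = (80, 257) matrix,
+# which is transposed to (257, 80) before being applied to the power
+# spectrogram.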
+
+
+class LogFbankProcessor:
+
+ def __init__(self):
+
+ self._eightk_method = "fillzero"
+ self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
+
+ self._hamming400 = np.hamming(400) # for 16k audio
+ self._hamming200 = np.hamming(200) # for 8k audio
+
+ def extract_spectrogram(self, wav, fs):
+ """Extract spectrogram features from waveform.
+ Args:
+ wav (1D array): waveform of the input
+ fs (int): sampling rate of the waveform, 16000 or 8000.
+ If fs=8000, the waveform will be resampled to 16000Hz.
+ Output:
+            spec (2D array): a TxF magnitude spectrogram, where F=257
+              frequency bins and T is the number of frames.
+ """
+ if wav.ndim > 1:
+ wav = np.squeeze(wav)
+
+ # by default, we extract the mean if stereo
+ if len(wav.shape) == 2:
+ wav = wav.mean(1)
+
+ # Resample to 16000 or 8000 if needed
+ if fs > 16000:
+ wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
+ fs = 16000
+ elif 8000 < fs < 16000:
+ wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
+ fs = 8000
+ elif fs < 8000:
+ raise RuntimeError(f"Unsupported sample rate {fs}")
+
+ if fs == 8000:
+ if self._eightk_method == "resample":
+ # Input audio is 8 kHz. Convert to 16 kHz before feature
+ # extraction
+ wav = scipy.signal.resample_poly(wav, 2, 1)
+ fs = 16000
+ # Do nothing here for fillzero method
+ elif fs != 16000:
+ # Input audio is not a supported sample rate.
+ raise RuntimeError(
+ f"Input data using an unsupported sample rate: {fs}")
+
+ preemphasis = 0.97
+
+ if fs == 8000:
+ n_fft = 256
+ win_length = 200
+ hop_length = 80
+ fft_window = self._hamming200
+ elif fs == 16000:
+ n_fft = 512
+ win_length = 400
+ hop_length = 160
+ fft_window = self._hamming400
+
+ # Spec 1: SpeechLib cut remaining sample insufficient for a hop
+ n_batch = (wav.shape[0] - win_length) // hop_length + 1
+        # Here we don't use stride_tricks since the input array may not
+        # satisfy the memory layout requirement and we need a writeable
+        # output. We only build a list of views before copying to the
+        # destination, so it is more efficient than broadcasting.
+ y_frames = np.array(
+ [
+ wav[_stride:_stride + win_length]
+ for _stride in range(0, hop_length * n_batch, hop_length)
+ ],
+ dtype=np.float32,
+ )
+
+ # Spec 2: SpeechLib applies preemphasis within each batch
+ y_frames_prev = np.roll(y_frames, 1, axis=1)
+ y_frames_prev[:, 0] = y_frames_prev[:, 1]
+ y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
+
+ S = np.fft.rfft(fft_window * y_frames, n=n_fft,
+ axis=1).astype(np.complex64)
+
+ if fs == 8000:
+ # Need to pad the output to look like 16 kHz data but with zeros in
+ # the 4 to 8 kHz bins.
+ frames, bins = S.shape
+ padarray = np.zeros((frames, bins))
+ S = np.concatenate((S[:, 0:-1], padarray),
+ axis=1) # Nyquist bin gets set to zero
+
+ spec = np.abs(S).astype(np.float32)
+ return spec
+
+ def extract_features(self, wav, fs):
+ """Extract log filterbank features from waveform.
+ Args:
+ wav (1D array): waveform of the input
+ fs (int): sampling rate of the waveform, 16000 or 8000.
+ If fs=8000, the waveform will be resampled to 16000Hz.
+ Output:
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
+ D=80, and T is the number of frames.
+ """
+ spec = self.extract_spectrogram(wav, fs)
+ spec_power = spec**2
+
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
+ log_fbank = np.log(fbank_power).astype(np.float32)
+
+ return log_fbank
+
+
+@lru_cache
+def audio_feature_extractor() -> LogFbankProcessor:
+    # Creates an instance of the audio processor, needed to extract the
+    # audio features from the sound file.
+    # The LRU cache ensures that we only make one copy.
+ return LogFbankProcessor()
+
+
+def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
+ vit_patch_size, token_compression_factor):
+ """
+    Compute the number of tokens an image is expected to take up, considering
+    the image encoder architecture, and exclude output features containing
+    only padding pixels.
+
+    For SigLIP, vit_image_size=448 and vit_patch_size=14, so the output is a
+    32x32 feature map.
+    NOTE: right now, Phi4MM uses a hard-coded token_compression_factor=2.
+ """
+ assert vit_image_size % vit_patch_size == 0, \
+ "vit_image_size must be divisible by vit_patch_size"
+ assert vit_image_size // vit_patch_size % token_compression_factor == 0, \
+ "vit_image_size // vit_patch_size must be divisible by "\
+ "token_compression_factor"
+
+ target_aspect_ratio, target_height, target_width = (
+ _find_target_aspect_ratio(image,
+ vit_image_size,
+ dynamic_hd_size,
+ min_num=1))
+ assert target_aspect_ratio[
+ 0] * vit_image_size == target_width, \
+ f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
+ assert target_aspect_ratio[
+ 1] * vit_image_size == target_height, \
+ f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
+ assert (target_height % vit_image_size == 0
+ and target_width % vit_image_size == 0)
+
+ padding_height, padding_width = _get_padding_size(image, target_height,
+ target_width)
+ assert padding_width == 0 or padding_height == 0, \
+ "padding_width or padding_height must be 0"
+
+ target_feat_width = target_width // vit_patch_size
+ target_feat_height = target_height // vit_patch_size
+ if padding_width >= vit_patch_size:
+ assert padding_height == 0, "padding_height not 0"
+ non_pad_feat_width = target_feat_width - math.floor(
+ padding_width / vit_patch_size)
+ non_pad_feat_height = target_feat_height
+ elif padding_height >= vit_patch_size:
+ assert padding_width == 0, "padding_width not 0"
+ non_pad_feat_height = target_feat_height - math.floor(
+ padding_height / vit_patch_size)
+ non_pad_feat_width = target_feat_width
+ else:
+ # small padding shorter than a vit patch
+ non_pad_feat_width = target_feat_width
+ non_pad_feat_height = target_feat_height
+
+ feat_width = non_pad_feat_width // token_compression_factor
+ feat_height = non_pad_feat_height // token_compression_factor
+ # NOTE it's possible that the non-padding feature is not divisible
+ if non_pad_feat_width % token_compression_factor != 0:
+ feat_width += 1
+ if non_pad_feat_height % token_compression_factor != 0:
+ feat_height += 1
+ num_hd_patch_tokens = feat_width * feat_height
+ num_hd_newline_tokens = feat_height
+ vit_feature_size = vit_image_size // vit_patch_size
+ num_global_image_tokens = (vit_feature_size // token_compression_factor)**2
+ num_sep_tokens = 1
+ num_global_image_newline_tokens = \
+ vit_feature_size // token_compression_factor
+
+ return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens +
+ num_hd_newline_tokens + num_global_image_newline_tokens)
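+
+# Worked example (assuming the default SigLIP settings): a 448x448 image maps
+# to a single 448x448 crop with no padding, so its 32x32 feature map is
+# compressed to 16x16 = 256 HD patch tokens plus 16 newline tokens; the
+# global crop adds another 256 tokens, 1 separator and 16 newline tokens,
+# giving 256 + 16 + 256 + 1 + 16 = 545 image tokens in total.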
+
+
+def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]:
+ """
+ Compute the output size of the `extract_features` method.
+
+ Args:
+ wav_length (int): Length of the input waveform in samples.
+ fs (int): Sampling rate of the waveform, either 16000 or 8000.
+
+ Returns:
+ tuple (int, int): Output size as (T, D), where:
+ T: Number of time frames.
+ D: Number of Mel filterbank bins (80).
+ """
+
+ # Resample to 16000 or 8000 if needed
+ if fs > 16000:
+ wav_length //= fs // 16000
+ fs = 16000
+ elif 8000 <= fs < 16000:
+ # We'll resample to 16K from 8K
+ wav_length *= 2
+ fs = 16000
+ elif fs < 8000:
+ raise RuntimeError(f"Unsupported sample rate {fs}")
+
+ # Spectrogram parameters for 16 kHz
+ win_length = 400 # Frame length in samples
+ hop_length = 160 # Frame shift in samples
+ mel_bins = 80 # Number of mel filterbank bins
+
+ # Calculate number of frames (T)
+ T = (wav_length - win_length) // hop_length + 1
+ if T < 1:
+ raise ValueError("Waveform too short for given parameters.")
+
+ # Return time frames (T) and mel bins (D)
+ return T, mel_bins
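+
+# Example: a 1-second clip at 16 kHz has wav_length = 16000, so
+# T = (16000 - 400) // 160 + 1 = 98 frames and D = 80 mel bins.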
+
+
+def _get_audio_embed_sizes(audios, ctx: InputContext):
+ """
+ Get the audio embedding sizes for each audio file.
+
+ Args:
+ audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of
+ waveform and sample rate.
+ ctx (InputContext): Input context.
+
+ Returns:
+ List[int]: List of audio embedding sizes.
+ """
+ audio_embed_sizes = []
+ for audio in audios:
+ audio_data, sf = audio
+ audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf)
+ audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(),
+ audio_frames)
+ audio_embed_sizes.append(audio_embed_size)
+ return audio_embed_sizes
+
+
+def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""):
+ """
+ The following will search for `<|audio_{idx}|>` tokens and
+ return a mapping of audio placeholder tokens to audio placeholder token ids
+ based on the size of the audio embeddings.
+
+ Args:
+ audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of
+ waveform and sample rate.
+ ctx (InputContext): Input context.
+ prompt_str (str): The prompt string.
+
+ Returns:
+ Dict[str, List[int]]: Mapping of audio placeholder tokens to audio
+ placeholder token ids.
+
+ """
+ if len(audios) == 0:
+ return {}
+
+ audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
+ audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str)
+ audio_ids = [int(audio_id) for audio_id in audio_ids]
+ assert len(audio_ids) == len(
+ audio_embed_sizes
+ ), "Number of audio tokens and audio features do not match"
+ assert tuple(audio_ids) == tuple(range(1,
+ len(audio_ids) +
+ 1)), "Audio ids are not in order!"
+ audio_id_to_input_ids = {
+ f"<|audio_{audio_id}|>":
+ [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
+ for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes)
+ }
+
+ return audio_id_to_input_ids
+
+
+def _count_image_tokens(images, ctx: InputContext):
+ hf_config = ctx.get_hf_config()
+ vision_encoder_name = hf_config.img_processor
+ if vision_encoder_name is None:
+ vision_encoder_name = SIGLIP_NAME
+ prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
+ dynamic_hd_size = prepro_config['dynamic_hd']
+ vit_image_size = prepro_config['vit_image_size']
+ vit_patch_size = prepro_config['vit_patch_size']
+ token_compression_factor = prepro_config['token_compression_factor']
+
+ image_token_counts = [
+ _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
+ vit_patch_size, token_compression_factor)
+ for image in images
+ ]
+ return image_token_counts
+
+
+def _get_image_id_to_input_ids(images, prompt, ctx: InputContext):
+ if len(images) == 0:
+ return {}
+
+ image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt)
+ image_ids = [int(image_id) for image_id in image_ids]
+ assert len(image_ids) == len(
+ set(image_ids)), "Duplicate image tokens in prompt"
+ assert len(images) == len(
+ image_ids), "Number of images and image tokens in prompt do not match"
+
+ # NOTE the following assertion is not strictly necessary
+ assert tuple(image_ids) == tuple(range(1,
+ len(image_ids) +
+ 1)), "Image ids are not in order"
+
+ image_token_counts = _count_image_tokens(images, ctx)
+ image_id_to_input_ids = {
+ f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens
+ for image_id, num_tokens in zip(image_ids, image_token_counts)
+ }
+ return image_id_to_input_ids
+
+
+def input_processor_for_phi4mm(ctx: InputContext,
+ inputs: DecoderOnlyInputs) -> TokenInputs:
+ """
+    Implements the input processor, which expands the image and audio
+    placeholder tokens in the input prompt ids. The result becomes the
+    `input_ids` in `forward` for the model.
+
+ Args:
+ ctx (InputContext): Input context.
+ inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids)
+ to process.
+
+ Returns:
+ TokenInputs: Processed inputs
+ """
+ multi_modal_data = inputs.get("multi_modal_data")
+ if (multi_modal_data is None or
+ ("audio" not in multi_modal_data and "image" not in multi_modal_data)):
+ # pure text input, so no need to do pre-processing
+ return inputs
+
+ prompt_str = inputs.get("prompt")
+ prompt_token_ids = inputs.get("prompt_token_ids")
+    # For offline inference, we get a string prompt and parse the MM special
+    # tokens from it (prompt_token_ids is ignored).
+    # For the OpenAI-compatible server, we get prompt_token_ids, where the MM
+    # special tokens have already been parsed.
+
+ if 'audio' in multi_modal_data:
+ audios = multi_modal_data["audio"]
+
+ if not isinstance(audios, list):
+ audios = [audios]
+ if prompt_str is not None:
+ audio_id_to_input_ids = _get_audio_id_to_input_ids(
+ audios, ctx, prompt_str=prompt_str)
+ audio_embed_sizes = []
+ elif prompt_token_ids is not None:
+ audio_id_to_input_ids = {}
+ audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
+ else:
+ audio_id_to_input_ids = {}
+ audio_embed_sizes = []
+
+ if 'image' in multi_modal_data:
+ # PIL Image or list of PIL Images
+ images = multi_modal_data["image"]
+ if not isinstance(images, list):
+ images = [images]
+ if prompt_str is not None:
+ image_id_to_input_ids = _get_image_id_to_input_ids(
+ images, prompt_str, ctx)
+ image_token_counts = []
+ elif prompt_token_ids is not None:
+ image_id_to_input_ids = {}
+ image_token_counts = _count_image_tokens(images, ctx)
+ else:
+ image_id_to_input_ids = {}
+ image_token_counts = []
+
+ # Handle the case where the prompt is a string and we need to manually
+ # tokenize it.
+    # In this case, the `audio_id_to_input_ids` dict maps from an audio
+    # placeholder string (e.g. `<|audio_1|>`) to the audio placeholder tokens
+    # for the given audio length.
+ if prompt_str:
+ pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)"
+ prompt_chunk_strings = re.split(pattern, prompt_str)
+ prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""]
+
+ # Create the new input_ids with the placeholder image and audio
+ # tokens inserted
+ tokenizer = cached_tokenizer_from_config(ctx.model_config)
+ input_ids = []
+        has_image, has_audio, has_user_text_input = False, False, False
+ for prompt_chunk_string in prompt_chunk_strings:
+ if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string):
+ input_ids.extend(image_id_to_input_ids[prompt_chunk_string])
+                has_image = True
+ elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string):
+ input_ids.extend(audio_id_to_input_ids[prompt_chunk_string])
+ has_audio = True
+ else:
+ curr_token_ids = tokenizer(prompt_chunk_string).input_ids
+ if not has_user_text_input:
+ for token_id in curr_token_ids:
+ if token_id not in NON_USER_INPUT_TOKENS:
+ has_user_text_input = True
+ break
+ input_ids.extend(curr_token_ids)
+        if has_audio and has_image and has_user_text_input:
+ raise ValueError(
+ "Phi4MMForCausalLM does not support text + audio + image" +
+ " inputs in the same prompt")
+ # Handle the case where the prompt is already tokenized
+ else:
+        assert prompt_token_ids is not None, \
+            "If a string prompt isn't provided, prompt_token_ids must be " \
+            "provided"
+
+ i = 0
+ input_ids = prompt_token_ids
+ # only needed for later assertion
+ img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0
+ image_token_count_iter = iter(image_token_counts)
+ audio_embed_size_iter = iter(audio_embed_sizes)
+ while i < len(input_ids):
+ token_id = input_ids[i]
+ if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID:
+ token_count = next(audio_embed_size_iter)
+ audio_cnt += 1
+ elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID:
+ token_count = next(image_token_count_iter)
+ img_cnt += 1
+ else:
+ user_text_input_cnt += 1 if token_id not in \
+ NON_USER_INPUT_TOKENS else 0
+ i += 1
+ continue
+ tokens = [token_id] * token_count
+ input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
+ i += token_count
+
+ if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0:
+ raise ValueError(
+ "Phi4MMForCausalLM does not support text + audio + image" +
+ " inputs in the same prompt")
+        # If the assertion below fails, the pure-text parts of the input
+        # messages might contain the image/audio special tokens literally
+        # (<|endoftext10|>, <|endoftext11|>).
+ assert (img_cnt == len(image_token_counts)), (
+ f"Number of image tokens in prompt_token_ids ({img_cnt}) "
+ f"does not match number of images ({len(image_token_counts)})")
+ assert (audio_cnt == len(audio_embed_sizes)), (
+ f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
+ f"does not match number of audios ({len(audio_embed_sizes)})")
+
+ # NOTE: Create a defensive copy of the original inputs
+ return token_inputs(
+ prompt_token_ids=input_ids,
+ prompt=prompt_str,
+ multi_modal_data=multi_modal_data,
+ )
+
+
+def _compute_audio_embed_size(hf_config, audio_frames):
+ """
+ Compute the audio embedding size based on the audio frames and
+ compression rate.
+ """
+ compression_rate = hf_config.embd_layer['audio_embd_layer'][
+ 'compression_rate']
+ # NOTE: this is a hard-coded value but might be configurable in the future
+ qformer_compression_rate = 1
+ integer = audio_frames // compression_rate
+ remainder = audio_frames % compression_rate
+
+ result = integer if remainder == 0 else integer + 1
+
+ integer = result // qformer_compression_rate
+ remainder = result % qformer_compression_rate
+ result = integer if remainder == 0 else integer + 1 # qformer compression
+
+ return result
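+
+# Example (assuming a compression_rate of 8 and qformer_compression_rate of
+# 1): 98 audio frames -> ceil(98 / 8) = 13 audio embedding positions.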
+
+
+def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int:
+ return 10000
+
+
+def dummy_audio_for_phi4mm(audio_count: int) -> dict:
+ """
+ Create dummy audio data for the Phi4MM model, which is used for profiling.
+
+ Args:
+ audio_count (int): Number of audio samples.
+
+ Returns:
+ dict: Dummy audio data.
+ """
+ dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0)
+ return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count
+
+
+def dummy_image_for_phi4mm(width: int, height: int):
+ image = Image.new('RGB', (width, height), color='black')
+ return image
+
+
+def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int,
+ mm_counts: Mapping[str, int]) -> DummyData:
+ """
+    Create dummy sequence (input_ids) and audio/image data for the Phi4MM
+    model, which is used for profiling.
+
+    In this case, the sequence data is a sequence of 0s with a number of
+    placeholder tokens corresponding to the audio embed size of
+    _AUDIO_MAX_SOUNDFILE_SIZE (or to the maximum image token count).
+
+ Args:
+ ctx (InputContext): Input context.
+ seq_len (int): Length of the sequence.
+ mm_counts (Mapping[str, int]): Multi-modal counts.
+
+ Returns:
+        DummyData: Dummy sequence data and dummy multi-modal data.
+ """
+ audio_count = mm_counts["audio"]
+ audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE,
+ DUMMY_SAMPLING_FREQUENCY)
+ audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(),
+ audio_frames)
+
+ image_count = mm_counts["image"]
+ dummy_image = get_max_dummy_image(ctx)
+ max_image_tokens = get_max_phi4mm_image_tokens(ctx)
+ total_image_tokens = image_count * max_image_tokens
+
+ if seq_len - audio_feature_size * audio_count - total_image_tokens < 0:
+ raise RuntimeError(
+ f"Phi4MM cannot process {audio_count} audios and {image_count}"
+ f"images in a prompt, please increase max_model_len to be at"
+ f" larger than "
+ f"{audio_feature_size * audio_count + total_image_tokens}"
+ " or reduce audio/image limit by --limit-mm-per-prompt.")
+
+ if audio_feature_size * audio_count > total_image_tokens:
+ seq_data = SequenceData.from_prompt_token_counts(
+ (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count),
+ (0, seq_len - audio_feature_size * audio_count),
+ )
+ mm_data = {
+ "audio": dummy_audio_for_phi4mm(audio_count),
+ }
+ else:
+ seq_data = SequenceData.from_prompt_token_counts(
+ (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens),
+ (0, seq_len - total_image_tokens),
+ )
+ mm_data = {
+ "image": [dummy_image] * image_count,
+ }
+ return DummyData(seq_data, mm_data)
+
+
+def input_mapper_for_phi4mm_audio(ctx: InputContext,
+ data: object) -> MultiModalInputs:
+ """
+ This function is used to create the MultiModalInputs for the Phi4MM
+ (audio) model.
+ Specifically, for audio, we extract the audio features from the sound
+ file and create pairs of audio features and audio embed lengths (the
+ latter of which is used to repeat the audio placeholder token in the
+ input prompt IDs).
+ These pairs are used, downstream, in `_audio_features_to_embeddings`
+ (via `_process_audio_input`).
+
+ Note that the incoming audio data (each entry in `data`) is a tuple of
+ the audio data and the sampling frequency (e.g. from soundfile.read).
+
+ Args:
+ ctx (InputContext): Input context.
+ data (object): Audio data.
+
+ Returns:
+ MultiModalInputs: Multi-modal inputs.
+ """
+ if not isinstance(data, list):
+ data = [data]
+
+ if len(data) == 0:
+ return MultiModalInputs()
+
+ audio_features = []
+ for audio_input in data:
+ if not isinstance(audio_input, tuple):
+ raise NotImplementedError(
+ f"Unsupported data type: {type(audio_input)}")
+
+ audio, sf = audio_input
+ feature_extractor = audio_feature_extractor()
+ single_audio_features = feature_extractor.extract_features(audio, sf)
+ feat_stride = (1 if not hasattr(feature_extractor, "stride") else
+ feature_extractor.stride)
+ audio_frames = len(single_audio_features) * feat_stride
+ single_audio_embed_size = _compute_audio_embed_size(
+ ctx.get_hf_config(), audio_frames)
+ single_audio_feature_audio_len_pair = (
+ single_audio_features,
+ [single_audio_embed_size],
+ )
+ audio_features.append(single_audio_feature_audio_len_pair)
+ return MultiModalInputs({"audio_features": audio_features})
+
+
+def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
+ if not isinstance(data, list):
+ data = [data]
+ # data: list of PIL images
+ if len(data) == 0:
+ return MultiModalInputs()
+ hf_config = ctx.get_hf_config()
+ vision_encoder_name = hf_config.img_processor
+ if vision_encoder_name is None:
+ vision_encoder_name = SIGLIP_NAME
+ prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
+ dynamic_hd_size = prepro_config['dynamic_hd']
+ vit_image_size = prepro_config['vit_image_size']
+ vit_patch_size = prepro_config['vit_patch_size']
+
+ image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size,
+ vit_patch_size)
+ return MultiModalInputs({
+ "pixel_values":
+ image_input_dict["pixel_values"],
+ "image_sizes":
+ image_input_dict["image_sizes"],
+ "image_attention_mask":
+ image_input_dict["image_attention_mask"],
+ "num_img_tokens":
+ image_input_dict["num_img_tokens"],
+ })
+
+
+def cat_with_pad(tensors, dim, padding_value=0):
+ """
+    Concatenate along `dim`, while padding all other dims to their max size.
+ """
+ ndim = tensors[0].dim()
+ assert all(
+ t.dim() == ndim for t in
+ tensors[1:]), "All tensors must have the same number of dimensions"
+
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
+ output = tensors[0].new_full(out_size, padding_value)
+
+ index = 0
+ for t in tensors:
+ # Create a slice list where every dimension except dim is full slice
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
+ # Update only the concat dimension slice
+ slices[dim] = slice(index, index + t.shape[dim])
+
+ output[slices] = t
+ index += t.shape[dim]
+
+ return output
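+
+# Usage sketch: cat_with_pad([torch.ones(2, 3), torch.ones(1, 4)], dim=1)
+# returns a (2, 7) tensor where the first tensor fills [:2, :3], the second
+# fills [:1, 3:7], and the remaining entries are the padding value (0).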
+
+
+@MULTIMODAL_REGISTRY.register_input_mapper("audio",
+ input_mapper_for_phi4mm_audio)
+@MULTIMODAL_REGISTRY.register_input_mapper("image",
+ input_mapper_for_phi4mm_image)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+ "audio", get_max_phi4mm_audio_tokens)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+ "image", get_max_phi4mm_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm)
+class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
+ """
+    Implements the Phi-4-multimodal-instruct model in vLLM.
+ """
+ # LoRA specific attributes
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "qkv_proj",
+ ],
+ "gate_up_proj": [
+ "gate_up_proj",
+ ],
+ }
+ supported_lora_modules = [
+ "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
+ ]
+ # Phi4MMForCausalLM does not apply LoRA to the embedding layer.
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ assert multimodal_config, "multimodal_config is required"
+ quant_config = vllm_config.quant_config
+ lora_config = vllm_config.lora_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+ self.quant_config = quant_config
+ self.lora_config = lora_config
+
+ # Tensor/Pipeline parallel not supported for now.
+ assert get_tensor_model_parallel_world_size(
+ ) == 1, "tensor parallel is not supported"
+ assert get_pp_group(
+ ).world_size == 1, "pipeline parallel is not supported"
+
+ self.vision_encoder = Phi4MMImageEncoder(
+ config,
+ quant_config,
+ prefix="model.vision_embed_tokens",
+ model_dir=config._name_or_path)
+
+ if isinstance(config.embd_layer["audio_embd_layer"], dict):
+ embedding_config = {
+ "embedding_cls":
+ config.embd_layer["audio_embd_layer"]["embedding_cls"],
+ **config.embd_layer["audio_embd_layer"],
+ }
+ else:
+ embedding_config = {
+ "embedding_cls": self.config.embd_layer["embedding_cls"]
+ }
+
+ self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
+ self.model = LlamaModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+
+ self.unpadded_vocab_size = config.vocab_size
+ if lora_config:
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ padding_size=(
+ DEFAULT_VOCAB_PADDING_SIZE
+ # We need bigger padding if using lora for kernel
+ # compatibility
+ if not lora_config else lora_config.lora_vocab_padding_size),
+ quant_config=quant_config,
+ )
+ if config.tie_word_embeddings:
+ self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+ logit_scale = getattr(config, "logit_scale", 1.0)
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+ config.vocab_size, logit_scale)
+ self.sampler = Sampler()
+
+ def _audio_features_to_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ input_features: List[torch.Tensor],
+ audio_input_sizes: torch.Tensor,
+ audio_projection_mode: str,
+ ) -> torch.Tensor:
+ """
+ Convert audio features to embeddings, which are used as input to the
+ model (via `inputs_embeds`).
+
+ Args:
+ input_ids (torch.Tensor): Input IDs (the prompt in this case).
+ input_features (list[torch.Tensor]): Input features (the audio
+ embeddings).
+ audio_input_sizes (list[torch.Tensor]): Audio input sizes (the
+ audio embed lengths to use for padding the audio placeholder token
+ in the input prompt IDs).
+ """
+ # The audio projection can either be a single linear or Sequential,
+ # so handle both cases
+ if isinstance(self.embed_tokens_extend.audio_projection,
+ nn.Sequential):
+ target_dtype = self.embed_tokens_extend.audio_projection[
+ 0].bias.dtype
+ else:
+ target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype
+
+ audio_input = [
+ input.unsqueeze(0).to(target_dtype) for input in input_features
+ ]
+ kwargs = {
+ "wte": self.model.embed_tokens,
+ 'audio_projection_mode': audio_projection_mode
+ }
+ audio_embeddings = self.embed_tokens_extend(input_ids, audio_input,
+ audio_input_sizes,
+ **kwargs)
+ audio_embeddings = audio_embeddings.to(target_dtype)
+ return audio_embeddings
+
+ def _parse_and_validate_audio_input(
+ self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
+ """
+ Parse and validate the audio input to the model. This handles both
+ audio features and audio embeddings, but only the former is used for
+ now.
+
+ Args:
+ kwargs (object): Keyword arguments.
+
+ Returns:
+ Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs.
+ """
+ audio_features = kwargs.pop("audio_features", None)
+ audio_embeds = kwargs.pop("audio_embeds", None)
+
+ if audio_features is None and audio_embeds is None:
+ return None
+
+ if audio_features is not None:
+ if not isinstance(audio_features, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of audio features. "
+ f"Got type: {type(audio_features)}")
+
+ return Phi4MMAudioFeatureInputs(type="audio_features",
+ data=audio_features)
+
+ if audio_embeds is not None:
+ if not isinstance(audio_embeds, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of audio embeds. "
+ f"Got type: {type(audio_embeds)}")
+
+ return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
+ data=audio_embeds)
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _process_audio_input(self, input_ids: torch.Tensor,
+ audio_input: Phi4MMAudioInputs,
+ audio_projection_mode: str) -> NestedTensors:
+ """
+ Create the audio embeddings from the audio input, where the audio input
+ is pairs of audio features and audio embed lengths. The audio input is
+ created by `input_mapper_for_phi4mm_audio`.
+
+ Args:
+ input_ids (torch.Tensor): Input IDs (the prompt in this case,
+ before the audio token replication).
+ audio_input (Phi4MMAudioInputs): Audio input.
+
+ Returns:
+ NestedTensors: Audio embeddings
+ """
+ if audio_input["type"] == "audio_embeds":
+ return audio_input["data"]
+
+ audio_features = audio_input["data"]
+        # The first dim of audio_features is the batch dim
+        # (e.g. multiple examples) and the second dim is the multi-audio dim
+        # (e.g. multiple audios in the same example)
+ audio_feature = [i[0] for j in audio_features for i in j]
+ audio_feature_len = [i[1].item() for j in audio_features for i in j]
+        # Add the batch dim via `unsqueeze` and remove it again via `squeeze`
+
+ return self._audio_features_to_embeddings(
+ input_ids.unsqueeze(0),
+ audio_feature,
+ audio_feature_len,
+ audio_projection_mode,
+ ).squeeze(0)
+
+ def _parse_and_validate_image_input(self,
+ **kwargs: object) -> Optional[Dict]:
+ pixel_values: Optional[Dict] = kwargs.get("pixel_values")
+ if pixel_values is None:
+ return None
+
+ image_sizes = kwargs.get("image_sizes")
+ image_attention_mask = kwargs.get("image_attention_mask")
+ num_img_tokens = kwargs.get("num_img_tokens")
+ assert image_sizes is not None and image_attention_mask is not None\
+ and num_img_tokens is not None, "Missing image inputs"
+
+ if isinstance(pixel_values, list):
+ assert pixel_values[0].dim() == 5, "Incorrect image inputs"
+ # list len is batch_size.
+ # each tensor has dimension: num_img_per_example, num_hd_patches,
+ # channels, height, width.
+ # need to pad along num_hd_patches.
+ # mask size num_img_per_prompt, num_hd_patches, feat_h, feat_w.
+ pixel_values = cat_with_pad(pixel_values, dim=0)
+ elif isinstance(pixel_values, torch.Tensor):
+ # dimension: batch_size, num_img_per_example, num_hd_patches,
+ # channels, height, width.
+ # we flatten first 2 dims to make it a single large batch for
+ # SigLIP Encoder.
+ assert pixel_values.dim() == 6, "Incorrect image inputs"
+ pixel_values = pixel_values.flatten(0, 1)
+ else:
+ raise ValueError("Incorrect pixel_values inputs")
+
+ if isinstance(image_attention_mask, list):
+ image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
+ elif isinstance(image_attention_mask, torch.Tensor):
+ image_attention_mask = image_attention_mask.flatten(0, 1)
+ else:
+ raise ValueError("Incorrect image_attention_mask inputs")
+
+ if isinstance(image_sizes, list):
+ image_sizes = torch.cat(image_sizes, dim=0)
+ elif isinstance(image_sizes, torch.Tensor):
+ image_sizes = image_sizes.flatten(0, 1)
+ else:
+ raise ValueError("Incorrect image_sizes inputs")
+
+ if isinstance(num_img_tokens, list):
+ num_img_tokens = [
+ n for num_tensor in num_img_tokens
+ for n in num_tensor.tolist()
+ ]
+ elif isinstance(num_img_tokens, torch.Tensor):
+ num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
+ else:
+ raise ValueError("Incorrect num_img_tokens inputs")
+
+ return {
+ 'pixel_values': pixel_values,
+ 'image_sizes': image_sizes,
+ 'image_attention_mask': image_attention_mask,
+ 'num_img_tokens': num_img_tokens,
+ }
+
+ def merge_image_features_to_inputs_embeds(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_set_tensors: List[torch.Tensor],
+ ):
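+ # Find the image placeholder positions in the prompt and write the
+ # projected image features into those embedding slots (no accumulation).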
+ position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero(
+ as_tuple=True)
+
+ assert all([t.shape[0] == 1 for t in image_set_tensors
+ ]), 'img_set_tensor should have shape (1, N_tokens, C)'
+ # Shape: (merged_N_tokens, C)
+ image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0)
+ image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to(
+ inputs_embeds.device)
+ merged_embeds = inputs_embeds.index_put(
+ indices=position_tuple,
+ values=image_set_tensor,
+ accumulate=False,
+ )
+ return merged_embeds
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> None:
+ weights = {name: weight for name, weight in weights}
+ adjusted_weights = {}
+
+ for name, weight in weights.items():
+ # NOTE vision-speech tasks use a separate projection layer
+ audio_proj_4v = \
+ "model.embed_tokens_extend.audio_embed.audio_projection.vision"
+ if name.startswith(audio_proj_4v):
+ name = name.replace(
+ audio_proj_4v,
+ "embed_tokens_extend.audio_projection_for_vision")
+
+ name = (name.replace(
+ "model.embed_tokens_extend.audio_embed."\
+ "audio_projection.speech.",
+ "embed_tokens_extend.audio_projection.",
+ ).replace(
+ "model.embed_tokens_extend.audio_embed.",
+ "embed_tokens_extend.",
+ ).replace("model.embed_tokens_extend.image_embed.",
+ "vision_encoder."))
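+ # Illustrative example (assuming a Sequential speech projection in
+ # the checkpoint): "model.embed_tokens_extend.audio_embed."
+ # "audio_projection.speech.0.weight" is remapped to
+ # "embed_tokens_extend.audio_projection.0.weight".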
+ # NOTE: this deals with LoRA injection, where `base_layer`
+ # remains as the original layer in the model
+ if name.endswith(".base_layer.weight"):
+ name = name.replace(".base_layer.weight", ".weight")
+ adjusted_weights[name] = weight
+
+ missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
+ strict=False)
+ logger.debug("*** missing keys:")
+ for key in missing_keys:
+ logger.debug(key)
+ logger.debug("**** unexpected keys:")
+ for key in unexpected_keys:
+ logger.debug(key)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ **kwargs: object,
+ ) -> torch.Tensor:
+ if intermediate_tensors is not None:
+ input_ids = None
+ inputs_embeds = None
+ else:
+ # Each entry in this is a pair of audio_features and audio_embed
+ # lengths
+ audio_input = self._parse_and_validate_audio_input(**kwargs)
+ image_inputs = self._parse_and_validate_image_input(**kwargs)
+
+ has_audio = audio_input is not None
+ has_image = image_inputs is not None
+
+ if has_audio:
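+ # Vision-speech inputs go through the dedicated vision projection
+ # head; audio-only inputs use the speech projection head.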
+ audio_projection_mode = 'vision' if has_image else 'speech'
+ inputs_embeds = self._process_audio_input(
+ input_ids, audio_input, audio_projection_mode)
+
+ if has_image:
+ dtype = self.vision_encoder.img_processor.embeddings.\
+ patch_embedding.weight.dtype
+ pixel_values = image_inputs['pixel_values'].to(dtype)
+ image_sizes = image_inputs['image_sizes']
+ image_attention_mask = image_inputs['image_attention_mask']
+ image_set_tensors = self.vision_encoder(
+ pixel_values, image_sizes, image_attention_mask)
+ if not has_audio:
+ inputs_embeds = self.model.embed_tokens(input_ids)
+
+ inputs_embeds = self.merge_image_features_to_inputs_embeds(
+ input_ids, inputs_embeds, image_set_tensors)
+
+ if has_image or has_audio:
+ # multi-modal input, we have set inputs_embeds properly in
+ # previous steps
+ input_ids = None
+ else:
+ # text-only, we keep using original input_ids
+ inputs_embeds = None
+
+ hidden_states = self.model(
+ input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py
new file mode 100644
index 0000000000000..f9d4881c55e29
--- /dev/null
+++ b/vllm/model_executor/models/phi4mm_audio.py
@@ -0,0 +1,1403 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
+# but implemented by the Phi-Speech team
+#!/usr/bin/env python3
+import abc
+import math
+from functools import partial
+from typing import Callable, Dict, List, Literal, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper)
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+ FullyShardedDataParallel)
+from torch.utils.checkpoint import checkpoint
+from transformers import PretrainedConfig
+
+from vllm.model_executor.models.phi4mm_utils import (
+ AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
+ MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias,
+ adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper,
+ get_offset, repeat, unfold_tensor, validate_checkpointing_config)
+
+_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
+
+
+def encoder_checkpoint_wrapper(
+ activation_checkpointing: Union[str, Dict],
+ layer_cls: type,
+ idx: int = 0,
+) -> Callable:
+ """return encoder activation checkpoint wrapper"""
+ validate_checkpointing_config(activation_checkpointing)
+
+ if isinstance(activation_checkpointing, str):
+ if activation_checkpointing:
+ if activation_checkpointing == "offload":
+ return offload_wrapper
+ return partial(checkpoint_wrapper)
+ return lambda x: x
+
+ if isinstance(activation_checkpointing, dict):
+ target_layer_cls = activation_checkpointing.get(
+ "module", "transformer")
+ if target_layer_cls.lower() == "transformer":
+ target_layer_cls = (
+ "EncoderLayer",
+ "ConformerEncoderLayer",
+ )
+ elif target_layer_cls.lower() == "attention":
+ target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
+ checkpointing_interval = activation_checkpointing.get("interval", 1)
+ offloading = activation_checkpointing.get("offload", False)
+ impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
+ "reentrant", True) else CheckpointImpl.NO_REENTRANT)
+
+ if (idx % checkpointing_interval == 0
+ and layer_cls.__name__ in target_layer_cls):
+ if offloading:
+ return offload_wrapper
+ return partial(checkpoint_wrapper, checkpoint_impl=impl)
+ return lambda x: x
+
+ raise ValueError("Invalid activation_checkpointing config")
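+
+
+# Illustrative (hypothetical) config accepted by the wrapper above, which
+# checkpoints every other conformer layer without CPU offloading:
+#   activation_checkpointing = {"module": "transformer", "interval": 2,
+#                               "offload": False}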
+
+
+class ConformerEncoderLayer(nn.Module):
+ """ConformerEncoder Layer module.
+ for more details see conformer paper:
+ https://arxiv.org/abs/2005.08100
+ This module implement the Conformer block layer.
+
+ Args:
+ d_model: int
+ attention dim.
+ ext_pw_out_channel: int
+ if > 0, ext_pw_out_channel is a dim channel size
+ for the last pointwise conv after swish activation.
+ depthwise_seperable_out_channel: int
+ if set to a value different from 0,
+ depthwise_seperable_out_channel will be used as the
+ channel_out of the second conv1d layer.
+ otherwise, if it equals 0, the second conv1d layer is skipped.
+ depthwise_multiplier: int
+ number of input_dim channels duplication. this value
+ will be used to compute the hidden channels of the Conv1D.
+ n_head: int
+ the number of heads for multihead attention module.
+ d_ffn: int
+ output size of the feed_forward blocks.
+ ext_pw_kernel_size: int
+ kernel size of the conv pointwise of the conformer.
+ kernel_size: int
+ kernel size.
+ dropout_rate: float
+ dropout rate.
+ causal: bool, optional
+ if set to True, convolutions have no access
+ to future frames. default False.
+ batch_norm: bool, optional
+ if set to True, apply batchnorm before activation
+ in ConvModule layer of the conformer.
+ default False
+ activation: str, optional
+ activation function name,
+ one of ["relu", "swish", "sigmoid"],
+ sigmoid activation is only used with "glu_in_fnn=True",
+ default "relu".
+ chunk_se: int, optional
+ 0 for offline SE.
+ 1 for streaming SE, where mean is computed
+ by accumulated history until current chunk_se.
+ 2 for streaming SE, where mean is computed
+ by only the current chunk.
+ default 0.
+ chunk_size: int, optional
+ chunk_size for cnn. default 18
+ conv_activation: str, optional
+ activation function used in ConvModule part
+ of the conformer, default "relu".
+ conv_glu_type: str, optional
+ activation function used for the glu inside
+ the ConvModule part of the conformer.
+ default: "sigmoid".
+ bias_in_glu: bool, optional
+ if set to True, use additive bias in the weight module
+ before GLU.
+ linear_glu_in_convm: bool, optional
+ if set to True, use GLULinear module,
+ otherwise, use the GLUPointWiseConv module.
+ defaults to False.
+ attention_innner_dim: int, optional
+ if equal to -1, attention dim for linears k/q/v is
+ equal to d_model. otherwise attention_innner_dim is used.
+ default -1.
+ attention_glu_type: str, optional
+ activation function for glu used in the multihead attention,
+ default "swish".
+ activation_checkpointing: str, optional
+ a dictionary of {"module","interval","offload"}, where
+ "module": str
+ accept ["transformer", "attention"] to select
+ which module should do activation checkpointing.
+ "interval": int, default 1,
+ interval of applying activation checkpointing,
+ interval = 1 means that we apply checkpointing
+ on every layer (if activation), otherwise,
+ we apply it every x interval.
+ "offload": bool, default False,
+ if set to True, we offload activation to cpu and
+ reload it during backward, otherwise,
+ we recalculate activation in backward.
+ default "".
+ export: bool, optional
+ if set to True, it removes the padding from convolutional layers
+ and allows the ONNX conversion for inference.
+ default False.
+ use_pt_scaled_dot_product_attention: bool, optional
+ if set to True, use pytorch's scaled dot product attention
+ implementation in training.
+ attn_group_sizes: int, optional
+ the number of groups to use for attention, default 1
+ (Multi-Head Attention),
+ 1 = typical Multi-Head Attention,
+ 1 < attn_group_sizes < attention_heads = Grouped-Query Attention
+ attn_group_sizes = attention_heads = Multi-Query Attention
+ """
+
+ def __init__(
+ self,
+ d_model=512,
+ ext_pw_out_channel=0,
+ depthwise_seperable_out_channel=256,
+ depthwise_multiplier=1,
+ n_head=4,
+ d_ffn=2048,
+ ext_pw_kernel_size=1,
+ kernel_size=3,
+ dropout_rate=0.1,
+ causal=False,
+ batch_norm=False,
+ activation="relu",
+ chunk_se=0,
+ chunk_size=18,
+ conv_activation="relu",
+ conv_glu_type="sigmoid",
+ bias_in_glu=True,
+ linear_glu_in_convm=False,
+ attention_innner_dim=-1,
+ attention_glu_type="swish",
+ activation_checkpointing="",
+ export=False,
+ use_pt_scaled_dot_product_attention=False,
+ attn_group_sizes: int = 1,
+ ):
+ super().__init__()
+
+ self.feed_forward_in = FeedForward(
+ d_model=d_model,
+ d_inner=d_ffn,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ bias_in_glu=bias_in_glu,
+ )
+
+ self.self_attn = encoder_checkpoint_wrapper(
+ activation_checkpointing,
+ MultiHeadedAttention,
+ )(MultiHeadedAttention(
+ n_head,
+ d_model,
+ dropout_rate,
+ attention_innner_dim,
+ attention_glu_type,
+ bias_in_glu,
+ use_pt_scaled_dot_product_attention=
+ use_pt_scaled_dot_product_attention,
+ group_size=attn_group_sizes,
+ ))
+ self.conv = ConvModule(
+ d_model,
+ ext_pw_out_channel,
+ depthwise_seperable_out_channel,
+ ext_pw_kernel_size,
+ kernel_size,
+ depthwise_multiplier,
+ dropout_rate,
+ causal,
+ batch_norm,
+ chunk_se,
+ chunk_size,
+ conv_activation,
+ conv_glu_type,
+ bias_in_glu,
+ linear_glu_in_convm,
+ export=export,
+ )
+
+ self.feed_forward_out = FeedForward(
+ d_model=d_model,
+ d_inner=d_ffn,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ bias_in_glu=bias_in_glu,
+ )
+
+ self.layer_norm_att = nn.LayerNorm(d_model)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(
+ self,
+ x,
+ pos_k,
+ pos_v,
+ mask,
+ relative_attention_bias: Optional[Tensor] = None,
+ ):
+ """ConformerEncoder forward.
+
+ Args:
+ x: torch.Tensor
+ input feature of shape (batch, max_time_in, size)
+ pos_k: torch.Tensor
+ positional key embedding.
+ mask: torch.Tensor
+ mask for x (batch, max_time_in)
+ relative_attention_bias: Optional[torch.Tensor]
+ bias added to attention logits w.r.t. relative positions
+ (1, n_head, time1, time2)
+ """
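+ # Macaron-style conformer block: half-step feed-forward, self-attention,
+ # convolution module, another half-step feed-forward, then a final
+ # layer norm.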
+ x = x + 0.5 * self.feed_forward_in(x)
+ norm_x = self.layer_norm_att(x)
+
+ x = x + self.self_attn(
+ norm_x,
+ norm_x,
+ norm_x,
+ pos_k,
+ pos_v,
+ mask,
+ relative_attention_bias=relative_attention_bias,
+ )
+ x = x + self.conv(x)
+ x = x + 0.5 * self.feed_forward_out(x)
+
+ out = self.layer_norm(x)
+
+ return out, pos_k, pos_v, mask
+
+
+class TransformerEncoderBase(abc.ABC, nn.Module):
+ """The Base class for Transformer based encoders
+
+ Please set causal = True in streaming model
+ Args:
+ input_size: int
+ input feature dimension.
+ chunk_size: int, list(int)
+ Number of frames for each chunk
+ This variable can take 2 forms:
+ int: Used for inference, or single chunk size training
+ list(int) : Used only for variable chunk size training
+ Some examples for the 2 cases:
+ chunk_size = 12
+ chunk_size = [6, 8, 12, 24]
+ left_chunk: int, list(int)
+ Number of chunks used for masking in streaming mode.
+ This variable can take 2 forms:
+ int: Used for inference, or single chunk size training
+ list(int) : Used only for variable chunk size training. When
+ chunk_size is a list, left_chunk must be a list with same length.
+ Some examples for the 2 cases:
+ left_chunk = 6
+ left_chunk = [12, 9, 6, 3]
+ attention_dim: int, optional
+ attention dimension. default 256.
+ attention_heads: int, optional
+ the number of heads. default 4
+ input_layer: str, optional
+ input layer type before Conformer,
+ one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
+ default "conv2d"
+ cnn_out: int, optional
+ the number of CNN channels before Conformer.
+ default -1.
+ cnn_layer_norm: bool, optional
+ layer norm between Conformer and the first CNN.
+ default False.
+ time_reduction: int, optional
+ time reduction factor
+ default 4
+ dropout_rate: float, optional
+ dropout rate. default 0.1
+ padding_idx: int, optional
+ padding index for input_layer=embed
+ default -1
+ relative_attention_bias_args: dict, optional
+ use more efficient scalar bias-based relative multihead attention
+ (Q*K^T + B) implemented in cmb.basics.embedding.
+ [T5/ALiBi]RelativeAttentionLogitBias
+ usage: relative_attention_bias_args={"type": t5/alibi}
+ additional method-specific arguments can be provided (see
+ transformer_base.py)
+ positional_dropout_rate: float, optional
+ dropout rate after positional encoding. default 0.0
+ nemo_conv_settings: dict, optional
+ A dictionary of settings for NeMo Subsampling.
+ default None
+ conv2d_extra_padding: str, optional
+ Add extra padding in conv2d subsampling layers. Choices are
+ (feat, feat_time, none, True).
+ if True or feat_time, the extra padding is added into non full
+ supraframe utts in batch.
+ Default: none
+ attention_group_size: int, optional
+ the number of groups to use for attention, default 1
+ (Multi-Head Attention),
+ 1 = typical Multi-Head Attention,
+ 1 < attention_group_size < attention_heads = Grouped-Query
+ Attention
+ attention_group_size = attention_heads = Multi-Query Attention
+ """
+
+ def __init__(
+ self,
+ input_size,
+ chunk_size,
+ left_chunk,
+ attention_dim=256,
+ attention_heads=4,
+ input_layer="nemo_conv",
+ cnn_out=-1,
+ cnn_layer_norm=False,
+ time_reduction=4,
+ dropout_rate=0.0,
+ padding_idx=-1,
+ relative_attention_bias_args=None,
+ positional_dropout_rate=0.0,
+ nemo_conv_settings=None,
+ conv2d_extra_padding: Literal["feat", "feat_time", "none",
+ True] = "none",
+ attention_group_size=1,
+ encoder_embedding_config=None,
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.input_layer = input_layer
+ self.chunk_size = chunk_size
+ self.left_chunk = left_chunk
+ self.attention_dim = attention_dim
+ self.num_heads = attention_heads
+ self.attention_group_size = attention_group_size
+ self.time_reduction = time_reduction
+ self.nemo_conv_settings = nemo_conv_settings
+ self.encoder_embedding_config = encoder_embedding_config
+
+ if self.input_layer == "nemo_conv":
+ default_nemo_conv_settings = {
+ "subsampling": "dw_striding",
+ "subsampling_factor": self.time_reduction,
+ "feat_in": input_size,
+ "feat_out": attention_dim,
+ "conv_channels": 256,
+ "subsampling_conv_chunking_factor": 1,
+ "activation": nn.ReLU(),
+ "is_causal": False,
+ }
+ # Override any of the defaults with the incoming, user settings
+ if nemo_conv_settings:
+ default_nemo_conv_settings.update(nemo_conv_settings)
+ for i in ["subsampling_factor", "feat_in", "feat_out"]:
+ assert (
+ i not in nemo_conv_settings
+ ), f"{i} should be specified outside of the NeMo dictionary"
+
+ self.embed = NemoConvSubsampling(**default_nemo_conv_settings, )
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+
+ self.pos_emb = AbsolutePositionalEncoding(attention_dim,
+ positional_dropout_rate)
+
+ self.relative_attention_bias_type = (
+ relative_attention_bias_args.get("type")
+ if relative_attention_bias_args else None)
+ if self.relative_attention_bias_type == "t5":
+ assert (self.num_heads % self.attention_group_size == 0
+ ), "attention_group_size must divide n_head"
+ self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
+ self.num_heads // self.attention_group_size,
+ max_distance=relative_attention_bias_args.get(
+ "t5_bias_max_distance", 1000),
+ symmetric=relative_attention_bias_args.get(
+ "t5_bias_symmetric", False),
+ )
+ else:
+ raise NotImplementedError(
+ "Only the 't5' relative_attention_bias type is supported")
+
+ def post_init(self, init_model_config):
+
+ pretrained_speech_encoder_path = init_model_config.get(
+ "pretrained_speech_encoder_path", None)
+ if pretrained_speech_encoder_path:
+ model_state = torch.load(pretrained_speech_encoder_path,
+ map_location="cpu")
+ encoder_state_dict = {}
+ for k, v in model_state.items():
+ if "encoder." in k:
+ tmp_k = k.replace("encoder.", "")
+ encoder_state_dict[tmp_k] = v
+
+ if hasattr(self, "encoder_embedding"):
+ del self.encoder_embedding
+ self.load_state_dict(encoder_state_dict)
+
+ if not hasattr(self, "encoder_embedding"):
+ self.encoder_embedding = MeanVarianceNormLayer(
+ self.encoder_embedding_config["input_size"])
+
+ def compute_lens_change(self, feature_lens):
+ """feature_lens: int
+ return updated feature lens.
+
+ This used to return a different lambda function for each case that
+ computed the right thing. That does not work within Torchscript.
+ If you really need this to be faster, create nn.Module()-s for all
+ the cases and return one of them. Torchscript does support that.
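+
+ Illustrative example (assuming the default non-causal "nemo_conv"
+ subsampling with time_reduction=4): 100 input frames map to
+ ceil(100 / 4) = 25 frames after subsampling.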
+ """
+ if self.input_layer == "nemo_conv":
+ # Handle the special causal case
+ subsampling_causal_cond = self.nemo_conv_settings.get(
+ "subsampling", "dw_striding") in [
+ "dw_striding",
+ "striding",
+ "striding_conv1d",
+ ]
+ is_causal = self.nemo_conv_settings.get("is_causal", False)
+ if is_causal and subsampling_causal_cond:
+ lens_change = (torch.ceil(feature_lens /
+ self.time_reduction).long()
+ if isinstance(feature_lens, Tensor) else
+ math.ceil(feature_lens / self.time_reduction))
+ feature_lens_remainder = feature_lens % self.time_reduction
+ if isinstance(feature_lens, Tensor):
+ lens_change[feature_lens_remainder != 1] += 1
+ elif feature_lens_remainder != 1:
+ lens_change += 1
+ return lens_change
+ ceil_func = (math.ceil
+ if isinstance(feature_lens, int) else torch.ceil)
+ return ceil_func(feature_lens / self.time_reduction)
+
+ @abc.abstractmethod
+ def forward(self):
+ """Abstract forward method implementation."""
+
+ def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
+ """If chunk size is a list, we will randomly select a chunk size."""
+
+ if chunk_size is None:
+ chunk_size = self.chunk_size
+ if left_chunk is None:
+ left_chunk = self.left_chunk
+ if isinstance(chunk_size, list):
+ # Variable chunk size during training
+ chunk_size_index = int(
+ torch.randint(low=0, high=len(chunk_size), size=(1, )))
+ chunk_size_train_eff = chunk_size[chunk_size_index]
+ if not isinstance(left_chunk, list):
+ raise ValueError(
+ "Since chunk_size is a list, left_chunk must be a list")
+ if len(left_chunk) != len(chunk_size):
+ raise ValueError(
+ "The length of left_chunk must be the same as length of "\
+ "chunk_size."
+ )
+ left_chunk_train_eff = left_chunk[chunk_size_index]
+ else:
+ chunk_size_train_eff = chunk_size
+ left_chunk_train_eff = left_chunk
+
+ return chunk_size_train_eff, left_chunk_train_eff
+
+ def _get_embed_class(self, embed):
+ # pylint: disable=protected-access
+ is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
+ is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
+ embed_class = embed
+ if is_embed_using_act_chkpt:
+ embed_class = embed._checkpoint_wrapped_module
+ if is_embed_fsdp_wrapped:
+ embed_class = embed.module
+ return embed_class
+
+ def _forward_embeddings_core(self, input_tensor, masks):
+ embed_class = self._get_embed_class(self.embed)
+ assert isinstance(embed_class, NemoConvSubsampling)
+ input_tensor, masks = self.embed(input_tensor, masks)
+ return input_tensor, masks
+
+ def _position_embedding(self, input_tensor):
+ pos_k = None
+ pos_v = None
+ if self.relative_attention_bias_layer is None:
+ input_tensor = self.pos_emb(
+ input_tensor) # default to add abs sinusoid embedding
+ return pos_k, pos_v
+
+ def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
+ chunk_size_train_eff, left_chunk_train_eff = \
+ self._chunk_size_selection(chunk_size, left_chunk)
+
+ # Create mask matrix for streaming
+ # S stores start index. if chunksize is 18, s is [0,18,36,....]
+ chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
+ # avoid randomness when run evaluation or decoding
+ if self.training and np.random.rand() > 0.5:
+ # Either first or last chunk is not complete.
+ # If only the last one is not complete, EOS is not effective
+ chunk_start_idx = seq_len - chunk_start_idx
+ chunk_start_idx = chunk_start_idx[::-1]
+ chunk_start_idx = chunk_start_idx[:-1]
+ chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
+
+ enc_streaming_mask = (adaptive_enc_mask(
+ seq_len, chunk_start_idx,
+ left_window=left_chunk_train_eff).unsqueeze(0).expand(
+ [batch_size, -1, -1]))
+ return enc_streaming_mask
+
+ def forward_embeddings(self,
+ xs_pad,
+ masks,
+ chunk_size_nc=None,
+ left_chunk_nc=None):
+ """Forwarding the inputs through the top embedding layers
+
+ Args:
+ xs_pad: torch.Tensor
+ input tensor
+ masks: torch.Tensor
+ input mask
+ chunk_size_nc: (optional, default is None) chunk size for
+ non-causal layers
+ left_chunk_nc: (optional, default is None) # of left chunks for
+ non-causal layers
+ """
+ # pylint: disable=R0915
+ # get new lens.
+ seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
+ if seq_len <= 0:
+ raise ValueError(
+ f"""The sequence length after time reduction is invalid:
+ {seq_len}. Your input feature is too short. Consider
+ filtering out very short utterances in the data
+ loader.""", )
+
+ batch_size = xs_pad.shape[0]
+
+ enc_streaming_mask = self._streaming_mask(seq_len, batch_size,
+ self.chunk_size,
+ self.left_chunk)
+
+ if xs_pad.is_cuda:
+ enc_streaming_mask = enc_streaming_mask.cuda()
+ xs_pad = xs_pad.cuda()
+
+ input_tensor = xs_pad
+ input_tensor, masks = self._forward_embeddings_core(
+ input_tensor, masks)
+
+ streaming_mask = enc_streaming_mask
+ if streaming_mask is not None and masks is not None:
+ hs_mask = masks & streaming_mask
+ elif masks is not None:
+ hs_mask = masks
+ else:
+ hs_mask = streaming_mask
+
+ if chunk_size_nc is not None:
+ enc_streaming_mask_nc = self._streaming_mask(
+ seq_len, batch_size, chunk_size_nc, left_chunk_nc)
+ if xs_pad.is_cuda:
+ enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
+ if masks is not None:
+ hs_mask_nc = masks & enc_streaming_mask_nc
+ else:
+ hs_mask_nc = enc_streaming_mask_nc
+ else:
+ hs_mask_nc = None
+
+ pos_k, pos_v = self._position_embedding(input_tensor)
+
+ if chunk_size_nc is None:
+ return input_tensor, pos_k, pos_v, hs_mask, masks
+ return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
+
+ def get_offset(self):
+ """Returns offset used when retaining inputs for decoding.
+
+ This is essentially, how many additional frames have to be added to
+ the front-end CNN input to ensure it can produce a single output.
+ So if the "padding" parameter is 0, typically offset will be > 0.
+ """
+ return get_offset(self.input_layer, self.time_reduction)
+
+
+class ConformerEncoder(TransformerEncoderBase):
+ """ConformerEncoder module.
+ see original paper for more details:
+ https://arxiv.org/abs/2005.08100
+
+ Please set causal = True in streaming model
+ Args:
+ input_size: int
+ input feature dimension.
+ chunk_size: int, list(int)
+ Number of frames for each chunk
+ This variable can take 2 forms:
+ int: Used for inference, or single chunk size training
+ list(int) : Used only for variable chunk size training
+ Some examples for the 2 cases:
+ chunk_size = 12
+ chunk_size = [6, 8, 12, 24]
+ left_chunk: int, list(int)
+ Number of chunks used for masking in streaming mode.
+ This variable can take 2 forms:
+ int: Used for inference, or single chunk size training
+ list(int) : Used only for variable chunk size training. When
+ chunk_size is a list, left_chunk must be a list with same length.
+ Some examples for the 2 cases:
+ left_chunk = 6
+ left_chunk = [12, 9, 6, 3]
+ num_lang: int
+ This parameter is used to store the number of languages in the
+ lang_dict, only used for multiseed/multilingual models.
+ default None.
+ attention_dim: int, optional
+ attention dimension. default 256.
+ attention_heads: int, optional
+ the number of heads. default 4
+ linear_units:
+ the number of units of position-wise feed forward.
+ default 2048
+ num_block:
+ number of Transformer layer. default 6
+ dropout_rate: float, optional
+ dropout rate. default 0.1
+ input_layer: str, optional
+ input layer type before Conformer,
+ one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
+ default "conv2d"
+ causal: bool, optional
+ if set to True, convolutions have no access
+ to future frames. default False.
+ batch_norm: bool, optional
+ if set to True, apply batchnorm before activation
+ in ConvModule layer of the conformer.
+ default False
+ cnn_out: int, optional
+ the number of CNN channels before Conformer.
+ default -1.
+ cnn_layer_norm: bool, optional
+ layer norm between Conformer and the first CNN.
+ default False.
+ ext_pw_out_channel: int, optional
+ the number of channel for CNN
+ before depthwise_seperable_CNN.
+ If 0 then use linear. default 0.
+ ext_pw_kernel_size: int, optional
+ kernel size of N before depthwise_seperable_CNN.
+ only work for ext_pw_out_channel > 0.
+ default 1
+ depthwise_seperable_out_channel: int, optional
+ the number of channel for
+ depthwise_seperable_CNN.
+ default 256.
+ depthwise_multiplier: int, optional
+ the number of multiplier for
+ depthwise_seperable_CNN.
+ default 1.
+ chunk_se: int, optional
+ 0 for offline SE.
+ 1 for streaming SE, where mean is computed
+ by accumulated history until current chunk_se.
+ 2 for streaming SE, where mean is computed
+ by only the current chunk.
+ default 0.
+ kernel_size: int, optional
+ the number of kernels for depthwise_seperable_CNN.
+ default 3.
+ activation: str, optional
+ FeedForward block activation.
+ one of ["relu", "swish", "sigmoid"]
+ default "relu".
+ conv_activation: str, optional
+ activation function used in ConvModule part
+ of the conformer, default "relu".
+ conv_glu_type: str, optional
+ activation function used for the GLU in depthwise_seperable_CNN,
+ default "sigmoid"
+ bias_in_glu: bool, optional
+ if set to True, use additive bias in the weight module
+ before GLU. default True
+ linear_glu_in_convm: bool, optional
+ if set to True, use GLULinear module,
+ otherwise, use the GLUPointWiseConv module.
+ defaults to False.
+ attention_glu_type: str
+ only work for glu_in_attention !=0
+ default "swish".
+ export: bool, optional
+ if set to True, it removes the padding from convolutional layers
+ and allows the ONNX conversion for inference.
+ default False.
+ activation_checkpointing: str, optional
+ a dictionary of {"module","interval","offload"}, where
+ "module": str
+ accept ["transformer", "attention"] to select
+ which module should do activation checkpointing.
+ "interval": int, default 1,
+ interval of applying activation checkpointing,
+ interval = 1 means that we apply checkpointing
+ on every layer (if activation), otherwise,
+ we apply it every x interval.
+ "offload": bool, default False,
+ if set to True, we offload activation to cpu and
+ reload it during backward, otherwise,
+ we recalculate activation in backward.
+ default "".
+ extra_layer_output_idx: int
+ the layer index to be exposed.
+ relative_attention_bias_args: dict, optional
+ use more efficient scalar bias-based relative multihead attention
+ (Q*K^T + B) implemented in cmb.basics.embedding.
+ [T5/ALiBi]RelativeAttentionLogitBias
+ usage: relative_attention_bias_args={"type": t5/alibi}
+ additional method-specific arguments can be provided (see
+ transformer_base.py)
+ time_reduction: int optional
+ time reduction factor
+ default 4
+ use_pt_scaled_dot_product_attention: whether to use pytorch scaled
+ dot product attention in training.
+ Default: False
+ nemo_conv_settings: dict, optional
+ A dictionary of settings for NeMo Subsampling.
+ default: None
+ usage: nemo_conv_settings=
+ {
+ "subsampling":
+ dw_striding/striding/dw_striding_conv1d/striding_conv1d,
+ "conv_channels": int,
+ "subsampling_conv_chunking_factor": int,
+ "is_causal": True/False
+ }
+ conv2d_extra_padding: str, optional
+ Add extra padding in conv2d subsampling layers. Choices are
+ (feat, feat_time, none, True)
+ Default: none
+ replication_pad_for_subsample_embedding: For batched-streaming
+ decoding, use "replication" padding for the cache at start of
+ utterance.
+ Default: False
+ attention_group_size: int, optional
+ the number of groups to use for attention, default 1
+ (Multi-Head Attention),
+ 1 = typical Multi-Head Attention,
+ 1 < attention_group_size < attention_heads = Grouped-Query
+ Attention
+ attention_group_size = attention_heads = Multi-Query Attention
+ """
+
+ extra_multi_layer_output_idxs: List[int]
+
+ def __init__( # pylint: disable-all
+ self,
+ input_size,
+ chunk_size,
+ left_chunk,
+ num_lang=None,
+ attention_dim=256,
+ attention_heads=4,
+ linear_units=2048,
+ num_blocks=6,
+ dropout_rate=0.1,
+ input_layer="nemo_conv",
+ causal=True,
+ batch_norm=False,
+ cnn_out=-1,
+ cnn_layer_norm=False,
+ ext_pw_out_channel=0,
+ ext_pw_kernel_size=1,
+ depthwise_seperable_out_channel=256,
+ depthwise_multiplier=1,
+ chunk_se=0,
+ kernel_size=3,
+ activation="relu",
+ conv_activation="relu",
+ conv_glu_type="sigmoid",
+ bias_in_glu=True,
+ linear_glu_in_convm=False,
+ attention_glu_type="swish",
+ export=False,
+ extra_layer_output_idx=-1,
+ extra_multi_layer_output_idxs=[], # noqa
+ activation_checkpointing="",
+ relative_attention_bias_args=None,
+ time_reduction=4,
+ use_pt_scaled_dot_product_attention=False,
+ nemo_conv_settings=None,
+ conv2d_extra_padding: Literal["feat", "feat_time", "none",
+ True] = "none",
+ replication_pad_for_subsample_embedding=False,
+ attention_group_size=1,
+ encoder_embedding_config=None,
+ ):
+ super().__init__(
+ input_size,
+ chunk_size,
+ left_chunk,
+ attention_dim,
+ attention_heads,
+ input_layer,
+ cnn_out,
+ cnn_layer_norm,
+ time_reduction,
+ dropout_rate=dropout_rate,
+ relative_attention_bias_args=relative_attention_bias_args,
+ positional_dropout_rate=0.0,
+ nemo_conv_settings=nemo_conv_settings,
+ conv2d_extra_padding=conv2d_extra_padding,
+ attention_group_size=attention_group_size,
+ encoder_embedding_config=encoder_embedding_config,
+ )
+ self.num_blocks = num_blocks
+ self.num_lang = num_lang
+ self.kernel_size = kernel_size
+ self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(
+ self.embed)
+ self.replication_pad_for_subsample_embedding: bool = (
+ replication_pad_for_subsample_embedding)
+ assert (self.num_heads % attention_group_size == 0
+ ), "attention_group_size must divide n_head"
+ self.num_heads_k = self.num_heads // attention_group_size
+
+ self.encoders = repeat(
+ num_blocks,
+ lambda i: encoder_checkpoint_wrapper(activation_checkpointing,
+ ConformerEncoderLayer, i)
+ (ConformerEncoderLayer(
+ d_model=attention_dim,
+ ext_pw_out_channel=ext_pw_out_channel,
+ depthwise_seperable_out_channel=
+ depthwise_seperable_out_channel,
+ depthwise_multiplier=depthwise_multiplier,
+ n_head=attention_heads,
+ d_ffn=linear_units,
+ ext_pw_kernel_size=ext_pw_kernel_size,
+ kernel_size=kernel_size,
+ dropout_rate=dropout_rate,
+ causal=causal,
+ batch_norm=batch_norm,
+ activation=activation,
+ chunk_se=chunk_se,
+ chunk_size=chunk_size,
+ conv_activation=conv_activation,
+ conv_glu_type=conv_glu_type,
+ bias_in_glu=bias_in_glu,
+ linear_glu_in_convm=linear_glu_in_convm,
+ attention_glu_type=attention_glu_type,
+ activation_checkpointing=attn_checkpointing(
+ activation_checkpointing, i),
+ export=export,
+ use_pt_scaled_dot_product_attention=
+ use_pt_scaled_dot_product_attention,
+ attn_group_sizes=attention_group_size,
+ )),
+ )
+ self.extra_layer_output_idx = extra_layer_output_idx
+ self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
+ # Make a zeros scalar we can use in get_initial_state to determine
+ # the device and the needed dtype:
+ self.register_buffer("dev_type", torch.zeros(()), persistent=False)
+
+ def init_relative_attention_bias(self, input_tensor):
+ if self.relative_attention_bias_layer:
+ return self.relative_attention_bias_layer(input_tensor)
+
+ def calculate_hs_mask(self, xs_pad, device, mask):
+ max_audio_length = xs_pad.shape[1]
+ batch_size = xs_pad.shape[0]
+ enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
+ self.chunk_size,
+ self.left_chunk)
+ enc_streaming_mask = enc_streaming_mask.to(device)
+ if mask is None:
+ return enc_streaming_mask
+
+ feature_lens = mask.sum(1)
+ padding_length = feature_lens
+ pad_mask = (torch.arange(0, max_audio_length,
+ device=device).expand(padding_length.size(0),
+ -1)
+ < padding_length.unsqueeze(1))
+ pad_mask = pad_mask.unsqueeze(1)
+ pad_mask = pad_mask & enc_streaming_mask
+ return pad_mask
+
+ @torch.jit.ignore
+ def forward(self, xs_pad, masks):
+ """Conformer Forward function
+
+ Args:
+ xs_pad: torch.Tensor
+ input tensor
+ masks: torch.Tensor
+ post-embedding input lengths
+ """
+ xs_pad = self.encoder_embedding(xs_pad)
+ input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
+ xs_pad, masks)
+
+ unfolded = False
+ ori_bz, seq_len, D = input_tensor.shape
+ max_seq_len = 500  # maximum position for absolute positional encoding
+ if seq_len > max_seq_len:
+ # audio sequence is longer than max_seq_len, unfold it into chunks
+ # of max_seq_len
+ unfolded = True
+ # the unfold op will drop residual frames, pad it to the multiple
+ # of max_seq_len
+ if seq_len % max_seq_len > 0:
+ chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
+ else:
+ chunk_pad_size = 0
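+ # e.g. (illustrative) seq_len=1203 gives chunk_pad_size=297, so the
+ # padded sequence unfolds into 3 chunks of max_seq_len=500.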
+ if chunk_pad_size > 0:
+ input_tensor_pad = F.pad(input_tensor,
+ (0, 0, 0, chunk_pad_size), "constant",
+ 0)
+ input_tensor = input_tensor_pad.to(input_tensor.device)
+ input_tensor = unfold_tensor(input_tensor, max_seq_len)
+ if masks is not None:
+ # revise hs_mask here because the previous calculated hs_mask
+ # did not consider extra pad
+ subsampled_pad_mask = masks.squeeze(
+ 1) # [bz, subsampled_unmask_seq_len]
+ extra_padded_subsampled_pad_mask = F.pad(
+ subsampled_pad_mask, (0, chunk_pad_size), "constant",
+ False) # extra padding to the pad mask
+ extra_padded_subsampled_pad_mask = \
+ extra_padded_subsampled_pad_mask.unsqueeze(-1).float()
+ masks_unfold = unfold_tensor(
+ extra_padded_subsampled_pad_mask, max_seq_len
+ ) # unfold the pad mask like we did to the input tensor
+ masks_unfold = masks_unfold.squeeze(
+ -1).bool() # unfold op does not support bool tensor
+ else:
+ masks_unfold = None
+ hs_mask = self.calculate_hs_mask(
+ input_tensor, input_tensor.device, masks_unfold
+ ) # calculate hs_mask based on the unfolded pad mask
+
+ # layer_emb = None
+
+ relative_attention_bias = self.init_relative_attention_bias(
+ input_tensor)
+
+ _simplified_path = (self.extra_layer_output_idx == -1
+ and relative_attention_bias is None)
+
+ if _simplified_path:
+ input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v,
+ hs_mask)
+ else:
+ for i, layer in enumerate(self.encoders):
+ input_tensor, _, _, _ = layer(
+ input_tensor,
+ pos_k,
+ pos_v,
+ hs_mask,
+ relative_attention_bias=relative_attention_bias,
+ )
+
+ # if i == self.extra_layer_output_idx:
+ # layer_emb = input_tensor
+
+ if unfolded:
+ embed_dim = input_tensor.shape[-1]
+ input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
+ # if we ever padded before unfolding, we need to remove the padding
+ if chunk_pad_size > 0:
+ input_tensor = input_tensor[:, :-chunk_pad_size, :]
+
+ return input_tensor, masks # , layer_emb
+
+ def gradient_checkpointing_enable(self):
+ pass
+
+
+class WindowQformer(nn.Module):
+ """Window-level Qformer"""
+
+ def __init__(
+ self,
+ window_size: int = 8,
+ num_queries: int = 1,
+ num_blocks: int = 2,
+ attention_dim: int = 512,
+ attention_heads: int = 8,
+ linear_units: int = 2048,
+ dropout_rate: float = 0.0,
+ normalize_before: bool = True,
+ ):
+ super().__init__()
+
+ self.decoders = nn.ModuleList([
+ nn.TransformerDecoderLayer(
+ d_model=attention_dim,
+ nhead=attention_heads,
+ dim_feedforward=linear_units,
+ dropout=dropout_rate,
+ activation="relu",
+ batch_first=True,
+ norm_first=normalize_before, # TODO need to verify
+ ) for _ in range(num_blocks)
+ ])
+
+ self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
+ self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
+ if normalize_before else None)
+ self.window_size = window_size
+ self.gradient_checkpointing_enable = False
+
+ def enable_gradient_checkpointing(self):
+ self.gradient_checkpointing_enable = True
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing_enable = False
+
+ def forward(self, audio_embed, mask, embed_len=None):
+ """forward decoder"""
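+ # Split the audio embedding into non-overlapping windows of
+ # `window_size` frames and cross-attend learned queries to each
+ # window, producing one pooled vector per window.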
+ # audio_embed: N x T x D => N x D x T
+
+ audio_embed = audio_embed.transpose(1, 2)
+ # audio_embed: N x D x 1 x T => N x DK x T'
+ padding = audio_embed.shape[-1] % self.window_size
+ if padding > 0:
+ audio_embed = F.pad(audio_embed, (0, self.window_size - padding),
+ "constant", 0)
+
+ embed_chunk = F.unfold(
+ audio_embed[..., None, :],
+ kernel_size=(1, self.window_size),
+ stride=(1, self.window_size),
+ )
+ bsz, _, slen = embed_chunk.shape
+ # N x D x K x T'
+ embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen)
+ # N x T' x K x D
+ embed_chunk = embed_chunk.transpose(1, 3).contiguous()
+ # NT' x K x D
+ embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1)
+ # NT' x 1 x D
+ q = self.queries.expand(bsz * slen, -1, -1)
+ for layer in self.decoders:
+ if self.gradient_checkpointing_enable and self.training:
+ q = checkpoint(
+ layer.__call__,
+ q,
+ embed_chunk,
+ None,
+ mask,
+ use_reentrant=True,
+ )
+ else:
+ q = layer(tgt=q,
+ memory=embed_chunk,
+ tgt_mask=None,
+ memory_mask=mask)
+
+ if self.after_norm is not None:
+ q = self.after_norm(q)
+
+ if embed_len is not None:
+ embed_len = embed_len // self.window_size
+ # N x T' x D
+ out = q.view(bsz, slen, -1)
+
+ return out, embed_len
+
+
+class AudioEmbedding(nn.Module):
+ """Audio embedding."""
+
+ def __init__(self, config: PretrainedConfig, **kwargs) -> None:
+ super().__init__()
+ self.config = config
+ # n_embed or hidden_size for text LM
+ hidden_size = (config.n_embd
+ if hasattr(config, "n_embd") else config.hidden_size)
+
+ if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"):
+ embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop")
+ else config.embed_pdrop)
+ self.drop = nn.Dropout(embd_drop)
+ else:
+ self.drop = None
+
+ # self.wte = nn.Embedding(config.vocab_size, hidden_size)
+
+ audio_dim_out = (
+ None # Set this variable according to the actual audio processor
+ )
+ self.layer_idx = -2
+
+ if (isinstance(config.audio_processor, dict)
+ and config.audio_processor.get("name", None) == "cascades"):
+ encoder_config = config.audio_processor.get("config", None)
+ assert encoder_config is not None
+ self.encoder = ConformerEncoder(**encoder_config)
+
+ # fake initialization, create encoder_embedding layer only so that
+ # in decoding, all parameters can be loaded in
+ # from_pretrained_function in training, we do post init after
+ # from_pretrained function to make sure the correct initialization
+ self.encoder.post_init({})
+
+ audio_dim_out = encoder_config["attention_dim"]
+ n_mels = encoder_config["input_size"]
+ else:
+ raise NotImplementedError(
+ "Only the 'cascades' audio processor config is supported")
+
+ assert (audio_dim_out
+ is not None), "Remember to set values for audio_dim_out"
+ self.audio_dim_out = audio_dim_out
+ self.audio_dim_in = n_mels
+
+ self.freeze_audio_processor = kwargs.get("freeze_audio_processor",
+ False)
+
+ self.downsample_rate = kwargs.get("downsample_rate", 1)
+
+ if kwargs.get("use_qformer", False):
+ qformer_config = kwargs.get("qformer_config", {})
+ qformer_config["attention_dim"] = audio_dim_out
+ self.qformer = WindowQformer(**qformer_config)
+ else:
+ self.qformer = None
+
+ if kwargs.get("use_conv_downsample", False):
+ assert (self.qformer is None
+ ), "don't support use qformer and conv downsample together"
+ nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
+ default_nemo_conv_settings = {
+ "subsampling": "dw_striding",
+ "subsampling_factor": self.downsample_rate,
+ "feat_in": audio_dim_out,
+ "feat_out": audio_dim_out,
+ "conv_channels": 256,
+ "subsampling_conv_chunking_factor": 1,
+ "activation": nn.ReLU(),
+ "is_causal": False,
+ }
+ # Override any of the defaults with the incoming, user settings
+ if nemo_conv_settings:
+ default_nemo_conv_settings.update(nemo_conv_settings)
+ for i in ["subsampling_factor", "feat_in", "feat_out"]:
+ assert (
+ i not in nemo_conv_settings
+ ), f"{i} should be specified outside of the NeMo dictionary"
+
+ self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, )
+ else:
+ self.conv_ds = None
+
+ enable_gradient_checkpointing = kwargs.get(
+ "enable_gradient_checkpointing", False)
+ if enable_gradient_checkpointing:
+ self.encoder.gradient_checkpointing_enable()
+
+ if self.qformer:
+ self.qformer.enable_gradient_checkpointing()
+
+ projection_cls = kwargs.get("projection_cls", "linear")
+ if projection_cls == "linear":
+ self.audio_projection = nn.Linear(audio_dim_out, hidden_size)
+ elif projection_cls == "mlp":
+ # follow llava-v1.5's implementation
+ # (do not use image_projection and image_proj_norm)
+ dim_projection = hidden_size
+ depth = 2
+ self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds)
+ else self.downsample_rate)
+ layers = [
+ nn.Linear(audio_dim_out * self.linear_downsample_rate,
+ dim_projection)
+ ]
+ for _ in range(1, depth):
+ layers.extend(
+ [nn.GELU(),
+ nn.Linear(dim_projection, dim_projection)])
+ self.audio_projection = nn.Sequential(*layers)
+ # NOTE vision-speech tasks use a separate projection layer
+ layers = [
+ nn.Linear(audio_dim_out * self.linear_downsample_rate,
+ dim_projection)
+ ]
+ for _ in range(1, depth):
+ layers.extend(
+ [nn.GELU(),
+ nn.Linear(dim_projection, dim_projection)])
+ self.audio_projection_for_vision = nn.Sequential(*layers)
+ else:
+ raise NotImplementedError(
+ f"projection_cls = {projection_cls}, not implemented")
+
+ # TODO: audio sequence compression - Qformer
+ self.vocab_size = config.vocab_size
+ self.input_embeds = None
+ self.audio_embed_sizes = None
+
+ def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None:
+ self.input_embeds = input_embeds
+
+ def set_audio_embed_sizes(self,
+ audio_embed_sizes: torch.LongTensor) -> None:
+ self.audio_embed_sizes = audio_embed_sizes
+
+ def get_audio_features(
+ self,
+ input_embeds: torch.FloatTensor,
+ audio_attention_mask: torch.Tensor = None,
+ audio_projection_mode: str = "speech",
+ ):
+
+ if self.freeze_audio_processor:
+ with torch.no_grad():
+ audio_features, masks = self.encoder(input_embeds,
+ audio_attention_mask)
+ else:
+ audio_features, masks = self.encoder(input_embeds,
+ audio_attention_mask)
+
+ if self.qformer is not None:
+ audio_features, _ = self.qformer(audio_features, mask=None)
+
+ if self.conv_ds is not None:
+ if masks is not None:
+ masks = masks.squeeze(1)
+
+ audio_features, masks = self.conv_ds(audio_features, mask=masks)
+
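+ # Stack every `linear_downsample_rate` consecutive frames along the
+ # feature dim before projection, e.g. a rate of 2 halves the sequence
+ # length and doubles the feature size.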
+ if self.linear_downsample_rate != 1:
+ bs, seq_len, feat_dim = audio_features.size()
+ padding = seq_len % self.linear_downsample_rate
+ if padding > 0:
+ audio_features = F.pad(
+ audio_features,
+ (0, 0, 0, self.linear_downsample_rate - padding),
+ "constant",
+ 0,
+ )
+
+ seq_len = audio_features.size(1)
+ audio_features = audio_features.view(
+ bs,
+ seq_len // self.linear_downsample_rate,
+ feat_dim * self.linear_downsample_rate,
+ )
+
+ if audio_projection_mode == 'speech':
+ audio_set_tensor = self.audio_projection(audio_features)
+ elif audio_projection_mode == 'vision':
+ audio_set_tensor = self.audio_projection_for_vision(audio_features)
+ else:
+ raise ValueError(
+ f"audio_projection_mode = {audio_projection_mode} not "\
+ "implemented"
+ )
+
+ return audio_set_tensor
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ input_embeds: torch.FloatTensor,
+ audio_embed_sizes,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ """
+ arguments:
+ input_ids: input text ids (B, U)
+ input_embeds: audio features (B, T, D) B: num audios in a sequence
+ """
+ assert input_embeds is not None and len(input_embeds) == len(
+ audio_embed_sizes)
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
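+ # `positions` below holds the (row, col) index of every audio
+ # placeholder token (<|endoftext11|>, id 200011) in the prompt; the
+ # audio embeddings are scattered into these slots further down.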
+ with torch.no_grad():
+ positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
+ as_tuple=False)
+
+ if not isinstance(input_embeds, list):
+ input_embeds = [input_embeds]
+
+ audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
+ audio_set_tensor = [
+ self.get_audio_features(
+ input_embed, audio_projection_mode=audio_projection_mode)
+ for input_embed in input_embeds
+ ]
+
+ with torch.no_grad():
+ input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
+
+ if "wte" in kwargs:
+ # we use the token embedding layer from the huggingface model, this
+ # is REQUIRED to make sure we are using the loaded weights.
+ hidden_states = kwargs["wte"](input_ids)
+ else:
+ # otherwise, we use token embedding in pretrained mixformer from
+ # phi team
+ hidden_states = self.wte(input_ids)
+
+ if len(positions.tolist()) > 0:
+ assert sum(audio_embed_sizes) == len(
+ positions
+ ), "please ensure the encoder outputs have the same length as"\
+ " defined in input_ids!"
+ idx = 0
+ for i in range(len(audio_embed_sizes)):
+ cnt = audio_embed_sizes[i]
+ assert audio_set_tensor[i].shape[0] == 1
+ hidden_states[
+ positions[idx, 0],
+ positions[idx, 1]:positions[idx, 1] + cnt,
+ ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
+ hidden_states.dtype).to(hidden_states.device))
+ idx += cnt
+
+ else:
+ if self.training:
+ # hidden_states[:, 0:img_set_tensor.shape[0]] =
+ # hidden_states[:, 0:img_set_tensor.shape[0]] +
+ # 0 * img_set_tensor.to(hidden_states.dtype)
+ # .to(hidden_states.device)
+ hidden_states[:, 0:1] = hidden_states[:, 0:1] + \
+ 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\
+ .to(hidden_states.device)
+
+ if self.drop is not None:
+ hidden_states = self.drop(hidden_states)
+ return hidden_states
diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py
new file mode 100644
index 0000000000000..16b62c60836e7
--- /dev/null
+++ b/vllm/model_executor/models/phi4mm_utils.py
@@ -0,0 +1,1969 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
+# but implemented by the Phi-Speech team
+#!/usr/bin/env python3
+import math
+from functools import partial
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ CheckpointImpl, checkpoint_wrapper, offload_wrapper)
+
+
+class Block(nn.Module):
+ """Block abstract module"""
+
+ def __init__(self, input_size, output_size):
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+
+
+def get_activation(name="relu"):
+ """Select an activation function by name
+
+ Args:
+ name: str
+ activation function name,
+ one of ["relu", "gelu", "swish", "sigmoid"],
+ default "relu".
+ """
+ name = name.lower()
+ if name == "relu":
+ return nn.ReLU(inplace=True)
+ if name == "gelu":
+ return nn.GELU()
+ if name == "swish":
+ return Swish()
+ if name == "sigmoid":
+ return torch.nn.Sigmoid()
+ return nn.Identity()
+
+
+def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+ """
+ The function is very important for Transformer Transducer Streaming mode
+ Args:
+ xs_len (int): sequence length
+ x_len (int): sequence length
+ It also supports adaptive chunk size [0,10,15,45]
+ left_window (int): how many left chunks can be seen
+ right_window (int): how many right chunks can be seen. It is used for
+ chunk overlap model.
+ Returns:
+ mask (torch.Tensor): a mask tensor for streaming model
+ Torch 1.0.1
+ tensor([[1., 1., 0., 0.],
+ [0., 1., 1., 0.],
+ [0., 0., 1., 1.]])
+ Torch 1.4.1
+ tensor([[True, True, False, False],
+ [False, True, True, False],
+ [False, False, True, True]])
+ """
+ chunk_start_idx = torch.Tensor(chunk_start_idx).long(
+ ) # first idx of each chunk, such as [0,18,36,48].
+ start_pad = torch.nn.functional.pad(
+ chunk_start_idx,
+ (1, 0)) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48]
+ end_pad = torch.nn.functional.pad(
+ chunk_start_idx, (0, 1), value=x_len
+ ) # append x_len to the end, so it becomes [0,18,36,48, x_len]
+ seq_range = torch.arange(0,
+ x_len).unsqueeze(-1) # seq_range size: [x_len, 1]
+ idx = ((seq_range < end_pad) &
+ (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len]
+ # boundary = end_pad[idx] # boundary size: [x_len]
+ seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
+ ) # seq_range_expand size [x_len, x_len]
+ idx_left = idx - left_window
+ idx_left[idx_left < 0] = 0
+ boundary_left = start_pad[idx_left]
+ mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+ idx_right = idx + right_window
+ idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+ boundary_right = end_pad[idx_right]
+ mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+ return mask_left & mask_right
+
+
+class Swish(nn.Module):
+ """Implement Swish activation module.
+ From https://arxiv.org/pdf/2005.03191.pdf
+
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.act_fn = nn.Sigmoid()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply Swish function
+
+ Args:
+ x: torch.Tensor
+ Input.
+ """
+ return x * self.act_fn(x)
+
+
+class GLU(nn.Module):
+ """Implement Gated Linear Unit (GLU) module"""
+
+ def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
+ super().__init__()
+ self.dim = dim
+ self.act_name = act_name.lower()
+
+ if self.act_name == "relu":
+ self.act_fn = nn.ReLU(inplace=True)
+ elif self.act_name == "gelu":
+ self.act_fn = nn.GELU()
+ elif self.act_name == "swish":
+ self.act_fn = Swish()
+ elif self.act_name == "sigmoid":
+ self.act_fn = nn.Sigmoid()
+ else:
+ self.act_fn = nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """GLU forward
+ Split the input in two along `dim` and multiply the first half
+ by the activation of the second half.
+
+ Args:
+ x: torch.Tensor
+ Input.
+
+ """
+ half_x, gate = x.chunk(2, dim=self.dim)
+ return half_x * self.act_fn(gate)
+
+
+# TODO: Abdel, this can be improved using GLU module
+class GLUPointWiseConv(nn.Module):
+ """GLUPointWiseConv module
+ used for conformer architecture,
+ for more details see:
+ https://arxiv.org/pdf/2005.08100v1.pdf
+
+ Args:
+ input_dim: int
+ input channel size.
+ output_dim: int
+ output channel size.
+ kernel_size: int
+ kernel size
+ glu_type: str, optional
+ activation function one of
+ ["sigmoid", "relu", "gelu"]
+ default "sigmoid".
+ bias_in_glu: bool, optional
+ use additive bias in GLU
+ causal: bool, optional
+ if set to True, padding is set to the half of
+ kernel size, ie, convolution can't see future frames.
+ default False.
+
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ kernel_size,
+ glu_type="sigmoid",
+ bias_in_glu=True,
+ causal=False,
+ ):
+ super().__init__()
+
+ self.glu_type = glu_type
+ self.output_dim = output_dim
+ self.bias_in_glu = bias_in_glu
+ if causal:
+ self.ext_pw_conv_1d = nn.Conv1d(
+ input_dim,
+ output_dim * 2,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1),
+ )
+ else:
+ self.ext_pw_conv_1d = nn.Conv1d(
+ input_dim,
+ output_dim * 2,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ )
+
+ if glu_type == "sigmoid":
+ self.glu_act = nn.Sigmoid()
+ elif glu_type == "relu":
+ self.glu_act = nn.ReLU()
+ elif glu_type == "gelu":
+ self.glu_act = nn.GELU()
+ elif glu_type == "swish":
+ self.glu_act = Swish()
+ else:
+ raise ValueError(f"Unsupported activation type {self.glu_act}")
+
+ if bias_in_glu:
+ self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
+ self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))
+
+ def forward(self, x):
+ """
+ Args:
+ x: torch.Tensor
+ input tensor
+ """
+ # to be consistent with GLULinear, we assume the input always has the
+ # #channel (#dim) in the last dimension of the tensor, so need to
+ # switch the dimension first for 1D-Conv case
+ x = x.permute([0, 2, 1])
+ x = self.ext_pw_conv_1d(x)
+ if self.glu_type == "bilinear":
+ if self.bias_in_glu:
+ x = (x[:, 0:self.output_dim, :] + self.b1) * (
+ x[:, self.output_dim:self.output_dim * 2, :] + self.b2)
+ else:
+ x = (x[:, 0:self.output_dim, :]) * (
+ x[:, self.output_dim:self.output_dim * 2, :])
+ else:
+ if self.bias_in_glu:
+ x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act(
+ x[:, self.output_dim:self.output_dim * 2, :] + self.b2)
+ else:
+ x = (x[:, 0:self.output_dim, :]) * self.glu_act(
+ x[:, self.output_dim:self.output_dim * 2, :])
+
+ x = x.permute([0, 2, 1])
+ return x
+
+
+class DepthWiseSeperableConv1d(nn.Module):
+ """DepthWiseSeperableConv1d module used in Convnet module
+ for the conformer, for more details see:
+ https://arxiv.org/pdf/2005.08100v1.pdf
+
+ Args:
+ input_dim: int
+ input channel size.
+        depthwise_seperable_out_channel: int
+            if non-zero, used as the output channel count of the
+            second (pointwise) conv1d layer.
+            if 0, the second conv1d layer is skipped.
+ kernel_size: int
+ kernel_size
+        depthwise_multiplier: int
+            multiplier applied to input_dim to obtain the hidden
+            channels of the depthwise Conv1d.
+ padding: int, optional
+ padding for the conv1d,
+ default: 0.
+
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ depthwise_seperable_out_channel,
+ kernel_size,
+ depthwise_multiplier,
+ padding=0,
+ ):
+ super().__init__()
+
+ self.dw_conv = nn.Conv1d(
+ input_dim,
+ input_dim * depthwise_multiplier,
+ kernel_size,
+ 1,
+ padding=padding,
+ groups=input_dim,
+ )
+
+ if depthwise_seperable_out_channel != 0:
+ self.pw_conv = nn.Conv1d(
+ input_dim * depthwise_multiplier,
+ depthwise_seperable_out_channel,
+ 1,
+ 1,
+ 0,
+ )
+ else:
+ self.pw_conv = nn.Identity()
+ self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
+
+ def forward(self, x):
+ """
+
+ Args:
+ x: torch.Tensor
+ input tensor
+ """
+ x = self.dw_conv(x)
+ if self.depthwise_seperable_out_channel != 0:
+ x = self.pw_conv(x)
+ return x
+
+
+class ConvModule(nn.Module):
+ """ConvModule Module for the conformer block.
+ for more details see:
+ https://arxiv.org/pdf/2005.08100v1.pdf
+
+ Args:
+ input_dim: int
+ input channel size.
+ ext_pw_out_channel: int
+ if > 0, ext_pw_out_channel is a dim channel size
+ for the last pointwise conv after swish activation.
+        depthwise_seperable_out_channel: int
+            if non-zero, used as the output channel count of the
+            second (pointwise) conv1d layer.
+            if 0, the second conv1d layer is skipped.
+ ext_pw_kernel_size: int
+ kernel size of the conv pointwise of the conformer.
+ kernel_size: int
+ kernel size.
+        depthwise_multiplier: int
+            multiplier applied to input_dim to obtain the hidden
+            channels of the depthwise Conv1d.
+ dropout_rate: float
+ dropout rate.
+        causal: bool, optional
+            if set to True, the convolution has no access
+            to future frames. default False.
+ batch_norm: bool, optional
+ if set to True, apply batchnorm before activation.
+ default False
+        chunk_se: int, optional
+            0 for offline SE.
+            1 for streaming SE, where the mean is computed from the
+            accumulated history up to the current chunk.
+            2 for streaming SE, where the mean is computed from
+            the current chunk only.
+ chunk_size: int, optional
+ chunk size for cnn. default 18
+ activation: str, optional
+ activation function used in ConvModule,
+ default: "relu".
+ glu_type: str, optional
+ activation function used for the glu,
+ default: "sigmoid".
+ bias_in_glu: bool, optional
+ if set to True, use additive bias in the weight module
+ before GLU.
+        linear_glu_in_convm: bool, optional
+            if set to True, use the GLULinear module,
+            otherwise, use the GLUPointWiseConv module.
+            default False.
+ export: bool, optional,
+ if set to True, padding is equal to 0. This is for inference,
+ or onnx export. Typically this is set by the export program or
+ the decoder program, and it isn't present in your config file.
+ default False
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ ext_pw_out_channel,
+ depthwise_seperable_out_channel,
+ ext_pw_kernel_size,
+ kernel_size,
+ depthwise_multiplier,
+ dropout_rate,
+ causal=False,
+ batch_norm=False,
+ chunk_se=0,
+ chunk_size=18,
+ activation="relu",
+ glu_type="sigmoid",
+ bias_in_glu=True,
+ linear_glu_in_convm=False,
+ export=False,
+ ):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(input_dim)
+ self.input_dim = input_dim
+ self.ext_pw_out_channel = ext_pw_out_channel
+ self.ext_pw_kernel_size = ext_pw_kernel_size
+ self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
+ self.glu_type = glu_type
+ self.bias_in_glu = bias_in_glu
+ self.linear_glu_in_convm = linear_glu_in_convm
+ self.causal = causal
+
+ self._add_ext_pw_layer()
+
+ self.batch_norm = batch_norm
+ self.kernel_size = kernel_size
+
+ if batch_norm:
+ self.bn_layer = nn.BatchNorm1d(input_dim)
+
+ self.act = get_activation(activation)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.export = export
+
+ if causal:
+ padding = 0 if export else kernel_size - 1
+ else:
+ padding = (kernel_size - 1) // 2
+
+ self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
+ input_dim,
+ depthwise_seperable_out_channel,
+ kernel_size,
+ depthwise_multiplier,
+ padding=padding,
+ )
+
+ if depthwise_seperable_out_channel != 0:
+ if input_dim != depthwise_seperable_out_channel:
+ self.ln2 = nn.Linear(depthwise_seperable_out_channel,
+ input_dim)
+ else:
+ if depthwise_multiplier != 1:
+ self.ln2 = nn.Linear(input_dim * depthwise_multiplier,
+ input_dim)
+
+ def _add_ext_pw_layer(self):
+ """
+ This function is an extension of __init__ function
+ and dedicated to the convolution module creation
+ of the conformer.
+ """
+ self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
+ nn.Identity()) # jit hacks.
+ self.squeeze_excitation = nn.Identity() # jit.
+ self.apply_ln1 = self.fix_len1 = False # jit.
+
+ if self.ext_pw_out_channel != 0:
+ if self.causal:
+ self.ext_pw_conv_1d = nn.Conv1d(
+ self.input_dim,
+ self.ext_pw_out_channel,
+ self.ext_pw_kernel_size,
+ 1,
+ padding=(self.ext_pw_kernel_size - 1),
+ )
+ if self.ext_pw_kernel_size > 1:
+ self.fix_len1 = True
+ else:
+ self.fix_len1 = False
+ else:
+ self.ext_pw_conv_1d = nn.Conv1d(
+ self.input_dim,
+ self.ext_pw_out_channel,
+ self.ext_pw_kernel_size,
+ 1,
+ padding=(self.ext_pw_kernel_size - 1) // 2,
+ )
+ self.fix_len1 = False
+
+ if self.linear_glu_in_convm:
+ self.glu = GLULinear(
+ self.input_dim,
+ self.ext_pw_out_channel,
+ self.glu_type,
+ self.bias_in_glu,
+ )
+ else:
+ self.glu = GLUPointWiseConv(
+ self.input_dim,
+ self.ext_pw_out_channel,
+ self.ext_pw_kernel_size,
+ self.glu_type,
+ self.bias_in_glu,
+ self.causal,
+ )
+
+ if self.input_dim != self.ext_pw_out_channel:
+ self.apply_ln1 = True
+ self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
+ else:
+ self.apply_ln1 = False
+ else:
+ self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
+ self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))
+
+ def forward(self, x):
+ """ConvModule Forward.
+
+ Args:
+ x: torch.Tensor
+ input tensor.
+ """
+ x = self.layer_norm(x)
+
+ if self.ext_pw_out_channel != 0:
+ x = self.glu(x)
+ if self.causal and self.ext_pw_kernel_size > 1:
+ x = x[:, :-(self.ext_pw_kernel_size - 1), :]
+ if self.apply_ln1:
+ x = self.ln1(x)
+ else:
+ x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
+ x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
+ x = x_0 + x_1
+
+ x = x.permute([0, 2, 1])
+
+ x = self.dw_sep_conv_1d(x)
+ if self.causal and self.kernel_size > 1:
+ x = x[:, :, :-(self.kernel_size - 1)]
+ if hasattr(self, "ln2"):
+ x = x.permute([0, 2, 1])
+ x = self.ln2(x)
+ x = x.permute([0, 2, 1])
+ if self.batch_norm:
+ x = self.bn_layer(x)
+ x = self.act(x)
+
+ if self.ext_pw_out_channel != 0:
+ x = self.ext_pw_conv_1d(x)
+ if self.fix_len1:
+ x = x[:, :, :-(self.ext_pw_kernel_size - 1)]
+
+ if self.apply_ln1:
+ x = x.permute([0, 2, 1])
+ x = self.ln1(x)
+ x = x.permute([0, 2, 1])
+
+ x = x.permute([0, 2, 1])
+ else:
+ x = x.unsqueeze(1).permute([0, 1, 3, 2])
+ x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
+ x = x.squeeze(1)
+
+ x = self.dropout(x)
+ return x
+
+
+class GLULinear(nn.Module):
+ """Linear + GLU module
+
+ Args:
+ input_dim: int
+ input size
+ output_dim: int
+ output size.
+        glu_type:
+            activation function name used in the glu module.
+            default "sigmoid".
+        bias_in_glu: bool, optional
+            If True, an additive bias is added. Default True.
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ glu_type="sigmoid",
+ bias_in_glu=True,
+ ):
+ super().__init__()
+ self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
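+        # the linear layer produces 2 * output_dim features; the GLU below
+        # gates them back down to output_dim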
+ self.glu_act = GLU(-1, glu_type)
+
+ def forward(self, x):
+ """GLULinear forward
+
+ Args:
+ x: torch.Tensor
+                input tensor.
+ """
+ x = self.linear(x)
+ return self.glu_act(x)
+
+
+class FeedForward(nn.Module):
+ """FeedForward Module.
+ For more details see Conformer paper:
+ https://arxiv.org/pdf/2005.08100.pdf
+
+ Args:
+ d_model: int
+ input size.
+        d_inner: int
+            inner (hidden) layer size.
+ dropout_rate: float,
+ dropout rate.
+ activation: str,
+ activation function name,
+ one of ["relu", "swish", "sigmoid"],
+ sigmoid activation is only used with "glu_in_fnn=True",
+ default "sigmoid".
+ bias_in_glu: bool, optional
+ """
+
+ def __init__(
+ self,
+ d_model,
+ d_inner,
+ dropout_rate,
+ activation="sigmoid",
+ bias_in_glu=True,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.d_inner = d_inner
+
+ self.layer_norm = nn.LayerNorm(d_model)
+ module = GLULinear(d_model, d_inner, activation, bias_in_glu)
+ self.net = nn.Sequential(
+ module,
+ nn.Dropout(dropout_rate),
+ nn.Linear(d_inner, d_model),
+ nn.Dropout(dropout_rate),
+ )
+
+ def forward(self, x):
+ """FeedForward forward function.
+
+ Args:
+ x: torch.Tensor
+ input tensor.
+ """
+ out = self.net(self.layer_norm(x))
+
+ return out
+
+
+#### positional encoding starts here
+def _pre_hook(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+):
+ """Perform pre-hook in load_state_dict for backward compatibility.
+
+ Note:
+ We saved self.pe until v.0.5.2 but we have omitted it later.
+ Therefore, we remove the item "pe" from `state_dict` for backward
+ compatibility.
+
+ """
+ k = prefix + "pe"
+ if k in state_dict:
+ state_dict.pop(k)
+
+
+class T5RelativeAttentionLogitBias(nn.Module):
+ """
+ This module implements the relative position bias described in Section
+ 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
+
+ The Huggingface implementation is used as a reference
+ https://github.com/huggingface/transformers/blob/v4.30.0/src/
+ transformers/models/t5/modeling_t5.py#L435
+
+ Modifies attention as Q*K^T + B, where B is a learned scalar bias based
+ on relative position of the query and key. It is HxNxN, where H is the
+ number of heads, N is the sequence length.
+
+ I've made these modifications to the original T5 bias:
+ - Skipping of the bucketing step. Original T5 bias converted rel
+ position distances into logarithmically increasing buckets. This is
+ supposed to help with length generalization.
+ - I just directly use rel position index as bias values, as we don't
+ need length generalization (40s max is good enough for ASR encoder),
+ and it keeps ONNX export simple.
+ - I've also extended it so that biases can be asymmetric, the default
+ implementation treats L->R and R->L the same. Asymmetric was found to
+ yield better results in my experiments.
+
+ Args:
+ num_heads: int
+ Number of attention heads
+ num_buckets: int
+ Number of buckets to use for relative attention bias. This is the
+ size of the learnable bias parameter. Bucketing is not yet
+ supported, so this defaults to -1 which means no bucketing is
+ used (max_distance determines size of bias param).
+ max_distance: int
+ Maximum distance to use for relative attention bias. With
+ num_buckets=-1, this directly controls the max size of the bias
+ parameter. When num_buckets > 0 is supported, this will control
+ the maximum distance for logarithmic bucketing after which all
+ positions are in the same bucket.
+ symmetric: bool
+ Whether to use symmetric or asymmetric biases. symmetric=False uses
+ 2x number of bias params to distinguish L->R from R->L. This was
+ found to be better for the encoder.
+ """
+
+ def __init__(self,
+ num_heads,
+ num_buckets=-1,
+ max_distance=1000,
+ symmetric=False):
+ super().__init__()
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.symmetric = symmetric
+ self._skip_bucketing = self.num_buckets < 0
+ if self._skip_bucketing:
+ self.num_buckets = max_distance
+ else:
+ raise NotImplementedError(
+ "T5 attention bias with bucketed positions is not yet tested")
+ if not self.symmetric:
+ self.num_buckets *= 2
+ self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)
+
+ def forward(self, x):
+ # instantiate bias compatible with shape of x
+ maxpos = x.size(1)
+ context_position = torch.arange(maxpos,
+ device=x.device,
+ dtype=torch.long)[:, None]
+ memory_position = torch.arange(maxpos,
+ device=x.device,
+ dtype=torch.long)[None, :]
+ relative_position = memory_position - context_position
+ # clipping to a maximum distance using ops that play well with ONNX
+ # export
+ relative_position = relative_position.masked_fill(
+ relative_position < -self.max_distance, -self.max_distance)
+ relative_position = relative_position.masked_fill(
+ relative_position > self.max_distance - 1, self.max_distance - 1)
+
+ # mapping from relative position to index in the bias parameter
+ if self._skip_bucketing:
+ bias_idx = relative_position
+ else:
+ bias_idx = self._bucket_relative_position(relative_position)
+ if self.symmetric:
+ bias_idx = bias_idx.abs()
+ else:
+ bias_idx += self.num_buckets // 2
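+            # with bucketing skipped (the only supported mode), num_buckets
+            # equals 2 * max_distance, so this shift maps the clipped
+            # distances into [0, 2 * max_distance)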
+
+ t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H]
+ t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(
+ 0) # [1, H, L, L]
+
+ return t5_rel_att_bias
+
+ def _bucket_relative_position(self, relative_position):
+        # this is a placeholder (untested, likely buggy) that uses the
+        # HuggingFace implementation as a reference; it still needs to be
+        # extended to support asymmetric +/- positions
+ relative_buckets = 0
+ if not self.causal:
+ self.num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(
+ torch.long) * self.num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position,
+ torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = self.num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in
+ # positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact) /
+ math.log(self.max_distance / max_exact) *
+ (self.num_buckets - max_exact)).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large,
+ torch.full_like(relative_position_if_large, self.num_buckets - 1),
+ )
+
+ relative_buckets += torch.where(is_small, relative_position,
+ relative_position_if_large)
+ return relative_buckets
+
+
+class AbsolutePositionalEncoding(nn.Module):
+ """Absolute Positional encoding module.
+ This module implement Absolute sinusoidal positional encoding
+ from: https://arxiv.org/pdf/1706.03762.pdf
+
+ Args:
+ d_model: int
+ Input embedding size.
+ dropout_rate: float
+ dropout rate
+        max_len: int, optional
+            Maximum input sequence length. Default 5000
+
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Construct an PositionalEncoding object."""
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+ self._register_load_state_dict_pre_hook(_pre_hook)
+
+ def extend_pe(self, x):
+ """Reset the positional encodings.
+
+ Args:
+ x: torch.Tensor
+ """
+ if self.pe is not None and self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
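+        # pe[:, 2i]   = sin(pos / 10000^(2i / d_model))
+        # pe[:, 2i+1] = cos(pos / 10000^(2i / d_model))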
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+
+ Args:
+ x: torch.Tensor
+ Input tensor. shape is (batch, time, ...)
+
+ Returns:
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, :x.size(1)]
+ return self.dropout(x)
+
+
+#### forward embedding layers starts here
+class MeanVarianceNormLayer(nn.Module):
+ """Mean/variance normalization layer.
+
+ Will subtract mean and multiply input by inverted standard deviation.
+ Typically used as a very first layer in a model.
+
+ Args:
+ input_size: int
+ layer input size.
+ """
+
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+ self.register_buffer("global_mean", torch.zeros(input_size))
+ self.register_buffer("global_invstd", torch.ones(input_size))
+ self.global_mean: Optional[Tensor]
+ self.global_invstd: Optional[Tensor]
+
+ def forward(self, input_: Tensor) -> Tensor:
+ """MeanVarianceNormLayer Forward
+
+ Args:
+ input_: torch.Tensor
+ input tensor.
+ """
+ return (input_ - self.global_mean) * self.global_invstd
+
+
+class CausalConv1D(nn.Conv1d):
+ """
+    A causal version of nn.Conv1d where each step has limited access to
+    locations on its right or left.
+    All arguments are the same as nn.Conv1d except padding.
+
+    If padding is set to None, paddings are set automatically to make it a
+    causal convolution where each location cannot see any steps on its right.
+
+ If padding is set as a list (size of 2), then padding[0] would be used as
+ left padding and padding[1] as right padding.
+    This makes it possible to control the number of steps accessible
+    on the right and left.
+ This mode is not supported when stride > 1. padding[0]+padding[1] should
+ be equal to (kernel_size - 1).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: Union[str, int] = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ self.cache_drop_size = None
+ if padding is None:
+ self._left_padding = kernel_size - 1
+ self._right_padding = stride - 1
+ else:
+ if stride != 1 and padding != kernel_size - 1:
+ raise ValueError(
+ "No striding allowed for non-symmetric convolutions!")
+ if isinstance(padding, int):
+ self._left_padding = padding
+ self._right_padding = padding
+ elif (isinstance(padding, list) and len(padding) == 2
+ and padding[0] + padding[1] == kernel_size - 1):
+ self._left_padding = padding[0]
+ self._right_padding = padding[1]
+ else:
+ raise ValueError(f"Invalid padding param: {padding}!")
+
+ self._max_cache_len = self._left_padding
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=0,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode=padding_mode,
+ device=device,
+ dtype=dtype,
+ )
+
+ def update_cache(self, x, cache=None):
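+        # streaming helper: when a cache is given, prepend it as left context
+        # instead of zero padding, and keep the tail of the padded input as
+        # the next cache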
+ if cache is None:
+ new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
+ next_cache = cache
+ else:
+ new_x = F.pad(x, pad=(0, self._right_padding))
+ new_x = torch.cat([cache, new_x], dim=-1)
+ if self.cache_drop_size > 0:
+ next_cache = new_x[:, :, :-self.cache_drop_size]
+ else:
+ next_cache = new_x
+ next_cache = next_cache[:, :, -cache.size(-1):]
+ return new_x, next_cache
+
+ def forward(self, x, cache=None):
+ x, cache = self.update_cache(x, cache=cache)
+ x = super().forward(x)
+ if cache is None:
+ return x
+ else:
+ return x, cache
+
+
+class CausalConv2D(nn.Conv2d):
+ """
+    A causal version of nn.Conv2d where each location in the 2D matrix has
+    no access to locations on its right or below.
+    All arguments are the same as nn.Conv2d except padding, which should be
+ set as None
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: Union[str, int] = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ if padding is not None:
+ raise ValueError(
+ "Argument padding should be set to None for CausalConv2D.")
+ self._left_padding = kernel_size - 1
+ self._right_padding = stride - 1
+
+ padding = 0
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ def forward(
+ self,
+ x,
+ ):
+ if self.training:
+ x = F.pad(
+ x,
+ pad=(
+ self._left_padding,
+ self._right_padding,
+ self._left_padding,
+ self._right_padding,
+ ),
+ )
+ else:
+ x = F.pad(
+ x,
+ pad=(self._left_padding, self._right_padding, 0, 0),
+ )
+ x = super().forward(x)
+ return x
+
+
+class NemoConvSubsampling(torch.nn.Module):
+ """Convlutional subsampling module, taken from NeMo ASR
+ (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
+ 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)
+
+ Striding Subsampling: "Speech-Transformer: A No-Recurrence
+ Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
+ et al. (https://ieeexplore.ieee.org/document/8462506)
+
+
+ Compared with the EncoderConv2D (`input_layer: custom`), this is a
+ much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
+ Moreover, depthwise convolutions are used to reduce FLOPs, but the first
+ layer is kept as a regular convolution so as not to degrade accuracy.
+
+ `Striding` and `dw_striding` are the same except that the latter uses
+ depthwise convolutions after the first layer, whereas the former does not.
+
+ Args:
+ subsampling_factor (int): Time reduction factor
+ feat_in (int): size of the input features
+ feat_out (int): size of the output features
+ subsampling (str): The subsampling technique, choose from
+ {"striding", "dw-striding", "striding_conv1d",
+ "dw_striding_conv1d"}
+ conv_channels (int): Number of channels for the convolution layers,
+ default is 256.
+ subsampling_conv_chunking_factor (int): Input chunking factor which
+ can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1
+ activation (Module): activation function, default is nn.ReLU()
+ is_causal (bool): whether to use causal Conv1/2D, where each step will
+ have limited access to locations on its right or left
+ """
+
+ def __init__(
+ self,
+ feat_in,
+ feat_out,
+ subsampling_factor=4,
+ subsampling="dw_striding",
+ conv_channels=256,
+ subsampling_conv_chunking_factor=1,
+ activation=nn.ReLU(), # noqa: B008
+ is_causal=False,
+ ):
+ super().__init__()
+ self._subsampling = subsampling
+ self._conv_channels = conv_channels
+ self._feat_in = feat_in
+ self._feat_out = feat_out
+
+ if subsampling_factor % 2 != 0:
+ raise ValueError("Sampling factor should be a multiply of 2!")
+ self._sampling_num = int(math.log(subsampling_factor, 2))
+ self.subsampling_factor = subsampling_factor
+ self.is_causal = is_causal
+ self.subsampling_causal_cond = subsampling in (
+ "dw_striding",
+ "striding",
+ "striding_conv1d",
+ )
+
+ if (subsampling_conv_chunking_factor != -1
+ and subsampling_conv_chunking_factor != 1
+ and subsampling_conv_chunking_factor % 2 != 0):
+ raise ValueError(
+ "subsampling_conv_chunking_factor should be -1, 1, or a "\
+ "power of 2"
+ )
+ self.subsampling_conv_chunking_factor = \
+ subsampling_conv_chunking_factor
+
+ in_channels = 1
+ layers = []
+
+ if subsampling == "dw_striding":
+ self._stride = 2
+ self._kernel_size = 3
+ self._ceil_mode = False
+
+ if self.is_causal:
+ self._left_padding = self._kernel_size - 1
+ self._right_padding = self._stride - 1
+ self._max_cache_len = subsampling_factor + 1
+ else:
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+ self._max_cache_len = 0
+
+ # Layer 1
+ if self.is_causal:
+ layers.append(
+ CausalConv2D(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ ))
+ in_channels = conv_channels
+ layers.append(activation)
+
+ for i in range(self._sampling_num - 1):
+ if self.is_causal:
+ layers.append(
+ CausalConv2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ groups=in_channels,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=in_channels,
+ ))
+
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ ))
+ layers.append(activation)
+ in_channels = conv_channels
+
+ elif subsampling == "striding":
+ self._stride = 2
+ self._kernel_size = 3
+ self._ceil_mode = False
+
+ if self.is_causal:
+ self._left_padding = self._kernel_size - 1
+ self._right_padding = self._stride - 1
+ self._max_cache_len = subsampling_factor + 1
+ else:
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+ self._max_cache_len = 0
+
+ for i in range(self._sampling_num):
+ if self.is_causal:
+ layers.append(
+ CausalConv2D(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ ))
+ layers.append(activation)
+ in_channels = conv_channels
+
+ elif subsampling == "striding_conv1d":
+ in_channels = feat_in
+
+ self._stride = 2
+ self._kernel_size = 5
+ self._ceil_mode = False
+
+ if self.is_causal:
+ self._left_padding = self._kernel_size - 1
+ self._right_padding = self._stride - 1
+ self._max_cache_len = subsampling_factor + 1
+ else:
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+ self._max_cache_len = 0
+
+ for i in range(self._sampling_num):
+ if self.is_causal:
+ layers.append(
+ CausalConv1D(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == i +
+ 1 else conv_channels),
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == i +
+ 1 else conv_channels),
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ ))
+ layers.append(activation)
+ in_channels = conv_channels
+
+ elif subsampling == "dw_striding_conv1d":
+ in_channels = feat_in
+
+ self._stride = 2
+ self._kernel_size = 5
+ self._ceil_mode = False
+
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+
+ # Layer 1
+ layers.extend([
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=in_channels,
+ ),
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == 1 else
+ conv_channels),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ ),
+ ])
+ in_channels = conv_channels
+ layers.append(activation)
+
+ for i in range(self._sampling_num - 1):
+ layers.extend([
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=in_channels,
+ ),
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == i +
+ 2 else conv_channels),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ ),
+ ])
+ layers.append(activation)
+ in_channels = conv_channels
+
+ else:
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+ if subsampling in ["dw_striding", "striding"]:
+ in_length = torch.tensor(feat_in, dtype=torch.float)
+ out_length = calc_length(
+ lengths=in_length,
+ all_paddings=self._left_padding + self._right_padding,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ ceil_mode=self._ceil_mode,
+ repeat_num=self._sampling_num,
+ )
+ self.out = torch.nn.Linear(conv_channels * int(out_length),
+ feat_out)
+ self.conv2d_subsampling = True
+ elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
+ self.out = None
+ self.conv2d_subsampling = False
+ else:
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+ self.conv = torch.nn.Sequential(*layers)
+
+ def get_sampling_frames(self):
+ return [1, self.subsampling_factor]
+
+ def get_streaming_cache_size(self):
+ return [0, self.subsampling_factor + 1]
+
+ def forward(self, x, mask):
+ """
+ Forward method for NeMo subsampling.
+
+ Args:
+ x[Batch, Time, Filters]: torch.Tensor
+ input tensor
+            mask: torch.Tensor
+                input mask
+
+ Returns:
+ x: torch.Tensor
+ Resulting tensor from subsampling (B, T //
+ time_reduction_factor, feat_out)
+ pad_mask: torch.Tensor
+ tensor of padded hidden state sequences (B, 1, T //
+ time_reduction_factor)
+ """
+ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)
+
+ # split inputs if chunking_factor is set
+ if (self.subsampling_conv_chunking_factor != -1
+ and self.conv2d_subsampling):
+ if self.subsampling_conv_chunking_factor == 1:
+ # if subsampling_conv_chunking_factor is 1, we split only
+ # if needed.
+ # avoiding a bug / feature limiting indexing of tensors
+ # to 2**31.
+ # see https://github.com/pytorch/pytorch/issues/80020
+ x_ceil = (2**31 / self._conv_channels * self._stride *
+ self._stride)
+ need_to_split = torch.numel(x) > x_ceil
+ else:
+ # if subsampling_conv_chunking_factor > 1 we always split
+ need_to_split = True
+
+ if need_to_split:
+ x, success = self.conv_split_by_batch(x)
+ if not success: # if unable to split by batch, try by channel
+ if self._subsampling == "dw_striding":
+ x = self.conv_split_by_channel(x)
+ else:
+ x = self.conv(x) # try anyway
+ else:
+ x = self.conv(x)
+ else:
+ x = self.conv(x)
+
+ # Flatten Channel and Frequency Axes
+ if self.conv2d_subsampling:
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).reshape(b, t, -1))
+ # Transpose to Channel Last mode
+ else:
+ x = x.transpose(1, 2)
+
+ if mask is None:
+ return x, None
+
+ max_audio_length = x.shape[1]
+ feature_lens = mask.sum(1)
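+        # valid output length after subsampling (ceil division by the
+        # subsampling factor)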
+ padding_length = torch.ceil(feature_lens / self.subsampling_factor)
+ if self.is_causal and self.subsampling_causal_cond:
+ feature_lens_remainder = feature_lens % self.subsampling_factor
+ padding_length[feature_lens_remainder != 1] += 1
+ pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
+ padding_length.size(0), -1) < padding_length.unsqueeze(1)
+ return x, pad_mask.unsqueeze(1)
+
+ def reset_parameters(self):
+ # initialize weights
+ if self._subsampling == "dw_striding":
+ with torch.no_grad():
+ # init conv
+ scale = 1.0 / self._kernel_size
+ dw_max = (self._kernel_size**2)**-0.5
+ pw_max = self._conv_channels**-0.5
+
+ torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
+ torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)
+
+ for idx in range(2, len(self.conv), 3):
+ torch.nn.init.uniform_(self.conv[idx].weight, -dw_max,
+ dw_max)
+ torch.nn.init.uniform_(self.conv[idx].bias, -dw_max,
+ dw_max)
+ torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max,
+ pw_max)
+ torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max,
+ pw_max)
+
+                # init fc (80 * 64 = 5120, from https://github.com/kssteven418/
+                # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
+                # src/models/conformer_encoder.py#L487)
+ fc_scale = (self._feat_out * self._feat_in /
+ self._sampling_num)**-0.5
+ torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
+ torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
+
+ def conv_split_by_batch(self, x):
+ """Tries to split input by batch, run conv and concat results"""
+ b, _, _, _ = x.size()
+ if b == 1: # can't split if batch size is 1
+ return x, False
+
+ if self.subsampling_conv_chunking_factor > 1:
+ cf = self.subsampling_conv_chunking_factor
+ else:
+ # avoiding a bug / feature limiting indexing of tensors to 2**31
+ # see https://github.com/pytorch/pytorch/issues/80020
+ x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
+ p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
+ cf = 2**p
+
+ new_batch_size = b // cf
+ if new_batch_size == 0: # input is too big
+ return x, False
+
+ return (
+ torch.cat([
+ self.conv(chunk)
+ for chunk in torch.split(x, new_batch_size, 0)
+ ]),
+ True,
+ )
+
+ def conv_split_by_channel(self, x):
+ """For dw convs, tries to split input by time, run conv and concat
+ results"""
+ x = self.conv[0](x) # full conv2D
+ x = self.conv[1](x) # activation
+
+ for i in range(self._sampling_num - 1):
+ _, c, t, _ = x.size()
+
+ if self.subsampling_conv_chunking_factor > 1:
+ cf = self.subsampling_conv_chunking_factor
+ else:
+ # avoiding a bug / feature limiting indexing of tensors
+ # to 2**31
+ # see https://github.com/pytorch/pytorch/issues/80020
+ p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
+ cf = 2**p
+
+ new_c = int(c // cf)
+ if new_c == 0:
+ new_c = 1
+
+ new_t = int(t // cf)
+ if new_t == 0:
+ new_t = 1
+
+ x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c,
+ x) # conv2D, depthwise
+
+ # splitting pointwise convs by time
+ x = torch.cat(
+ [
+ self.conv[i * 3 + 3](chunk)
+ for chunk in torch.split(x, new_t, 2)
+ ],
+ 2,
+ ) # conv2D, pointwise
+ x = self.conv[i * 3 + 4](x) # activation
+ return x
+
+ def channel_chunked_conv(self, conv, chunk_size, x):
+ """Performs channel chunked convolution"""
+
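+        # run the depthwise conv in channel chunks, slicing conv.weight and
+        # conv.bias per chunk, to stay below the 2**31 indexing limit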
+ ind = 0
+ out_chunks = []
+ for chunk in torch.split(x, chunk_size, 1):
+ step = chunk.size()[1]
+
+ if self.is_causal:
+ chunk = nn.functional.pad(
+ chunk,
+ pad=(
+ self._kernel_size - 1,
+ self._stride - 1,
+ self._kernel_size - 1,
+ self._stride - 1,
+ ),
+ )
+ ch_out = nn.functional.conv2d(
+ chunk,
+ conv.weight[ind:ind + step, :, :, :],
+ bias=conv.bias[ind:ind + step],
+ stride=self._stride,
+ padding=0,
+ groups=step,
+ )
+ else:
+ ch_out = nn.functional.conv2d(
+ chunk,
+ conv.weight[ind:ind + step, :, :, :],
+ bias=conv.bias[ind:ind + step],
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=step,
+ )
+ out_chunks.append(ch_out)
+ ind += step
+
+ return torch.cat(out_chunks, 1)
+
+ def change_subsampling_conv_chunking_factor(
+ self, subsampling_conv_chunking_factor: int):
+ if (subsampling_conv_chunking_factor != -1
+ and subsampling_conv_chunking_factor != 1
+ and subsampling_conv_chunking_factor % 2 != 0):
+ raise ValueError(
+ "subsampling_conv_chunking_factor should be -1, 1, or a "\
+ "power of 2"
+ )
+ self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
+
+
+def calc_length(lengths,
+ all_paddings,
+ kernel_size,
+ stride,
+ ceil_mode,
+ repeat_num=1):
+ """Calculates the output length of a Tensor passed through a convolution or
+ max pooling layer"""
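+    # per layer: L_out = (L_in + all_paddings - kernel_size) / stride + 1,
+    # then floor (or ceil when ceil_mode), repeated `repeat_num` times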
+ add_pad: float = all_paddings - kernel_size
+ one: float = 1.0
+ for i in range(repeat_num):
+ lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) +
+ one)
+ lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
+ return lengths.to(dtype=torch.int)
+
+
+#### multihead attention starts here
+class AttModule(nn.Module):
+ """Attention abstraction module"""
+
+ def __init__(self):
+ super().__init__()
+ self.export_mode = False
+
+ def set_export(self, mode=True):
+ """set the export mode"""
+ self.export_mode = mode
+
+ def forward(
+ self,
+ x: Tensor,
+ memory: Optional[Tensor] = None,
+ pos_emb: Optional[Tensor] = None,
+ att_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+ """AttModule forward
+
+ Args:
+ x: torch.Tensor
+ input tensor.
+ memory: torch.Tensor, optional
+ memory tensor.
+ pos_emb: torch.Tensor, optional
+ positional encoder embedding.
+ att_mask: torch.Tensor, optional
+ attention mask tensor.
+ """
+ return x, memory, pos_emb, att_mask
+
+
+class AttBlock(Block, AttModule):
+ """Attention Block module to support both Attention and Block module."""
+
+ def memory_dims(self, max_len=False):
+ """memory dimensions"""
+ return (1, self.input_size)
+
+
+def masked_softmax(
+ scores,
+ mask: Optional[Tensor],
+):
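+    # positions where mask == 0 are excluded: their scores are set to -inf
+    # before the softmax and their attention weights zeroed afterwards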
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
+ scores = scores.masked_fill(mask, -torch.inf)
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ return attn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer with optional relative position embedding
+ and GLU.
+
+ Args:
+ n_head: int
+ the number of heads.
+ n_feat: int
+ input size features.
+ dropout_rate: float
+ dropout rate.
+ use_LN: bool
+ apply layer norm or not
+ dropout_at_output: bool
+ whether to apply dropout at output
+ attention_inner_dim: int, optional
+ the attention dimension used in the class,
+ it can be different from the input dimension n_feat.
+ default: -1 (equal to n_feat).
+ use_pt_scaled_dot_product_attention: bool, optional
+ if set True, use pytorch scaled dot product attention in training.
+ NOTE: this will NOT be used in ONNX decoding due to a lack of
+ support. In that case, we use the original attention
+ implementation, which shows no regression.
+ default: False.
+ n_value: int, optional
+ if set to values other than -1, use a different dimension for
+ value. With the default value (i.e. -1), it is backward compatible.
+ group_size: int, optional. must divide `n_head`
+ if group_size > 1: GQA
+ if group_size = 1: MHA
+ if group_size = n_head: MQA
+ """
+
+ inv_sqrt_d_k: torch.jit.Final[float]
+ h: torch.jit.Final[int]
+ h_k: torch.jit.Final[int]
+ g: torch.jit.Final[int]
+
+ def __init__(
+ self,
+ n_head,
+ n_feat,
+ dropout_rate,
+ attention_inner_dim=-1,
+ glu_type="swish",
+ bias_in_glu=True,
+ use_pt_scaled_dot_product_attention=False,
+ n_value=-1,
+ group_size: int = 1,
+ ):
+ super().__init__()
+ if n_value == -1:
+ n_value = n_feat
+ if attention_inner_dim == -1:
+ attention_inner_dim = n_feat
+ assert attention_inner_dim % n_head == 0
+
+ # We assume d_v always equals d_k
+ self.d_k = attention_inner_dim // n_head
+ self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
+ self.h = n_head
+ assert n_head % group_size == 0, "group_size must divide n_head"
+ self.g = group_size
+ self.h_k = n_head // group_size
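+        # h is the number of query heads; h_k = h / group_size is the
+        # number of key/value heads. group_size == 1 -> MHA,
+        # 1 < group_size < n_head -> GQA, group_size == n_head -> MQA.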
+
+ self.linear_q = nn.Linear(n_feat, attention_inner_dim)
+ self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
+ self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
+ self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)
+
+ self.attn = torch.jit.Attribute(None, Optional[Tensor])
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.dropout_rate = dropout_rate
+ self.use_pt_scaled_dot_product_attention = (
+ use_pt_scaled_dot_product_attention)
+
+ if use_pt_scaled_dot_product_attention and group_size > 1:
+ raise ValueError("Cannot use PT Scaled Attention with GQA")
+
+ # Torchscript eager quantization. Note that these functions below are
+ # NOOPs and have very little impact on performance unless quantization
+ # is enabled.
+ self.quant_q = torch.ao.quantization.QuantStub()
+ self.quant_x = torch.ao.quantization.QuantStub()
+ self.dequant = torch.ao.quantization.DeQuantStub()
+ self.ffunc = torch.ao.nn.quantized.FloatFunctional()
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_k: Tensor,
+ pos_v: Tensor,
+ mask: Optional[Tensor],
+ relative_attention_bias: Optional[Tensor] = None,
+ ):
+ """Compute 'Scaled Dot Product Attention'.
+
+ Args:
+ query: torch.Tensor
+ query tensor (batch, time1, size)
+ key: torch.Tensor
+ key tensor (batch, time2, size)
+            value: torch.Tensor
+                value tensor (batch, time2, size)
+ pos_k: torch.Tensor
+ key tensor used for relative positional embedding.
+ pos_v: torch.Tensor
+ value tensor used for relative positional embedding.
+ mask: torch.Tensor
+ mask tensor (batch, time1, time2)
+ relative_attention_bias: torch.Tensor
+ bias added to attention logits w.r.t. relative positions
+ (1, n_head, time1, time2)
+ """
+ n_batch = query.size(0)
+
+ q = self.linear_q(query).view(n_batch, -1, self.h,
+ self.d_k) # (b, t, d)
+ k = self.linear_k(key).view(n_batch, -1, self.h_k,
+ self.d_k) # (b, t, d)
+ v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
+ q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention
+ and not torch.jit.is_scripting() else q.transpose(1, 2) *
+ self.inv_sqrt_d_k)
+ k = k.transpose(1, 2) # (batch, head_k, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head_k, time2, d_k)
+
+ if (self.use_pt_scaled_dot_product_attention
+ and not torch.jit.is_scripting()):
+ attn_mask = None
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ if relative_attention_bias is not None:
+ attn_mask = mask + relative_attention_bias
+ else:
+ attn_mask = mask
+ if mask.dtype != q.dtype:
+ attn_mask = attn_mask.to(q.dtype)
+
+ with torch.backends.cuda.sdp_kernel(enable_flash=True,
+ enable_math=True,
+ enable_mem_efficient=True):
+ x = torch.nn.functional.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ dropout_p=self.dropout_rate,
+ )
+ else:
+ if self.h != self.h_k:
+ q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
+ A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
+ else:
+ A = torch.matmul(q, k.transpose(-2, -1))
+ if pos_k is not None:
+ if self.h != self.h_k:
+ B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
+ else:
+ reshape_q = (q.contiguous().view(n_batch * self.h, -1,
+ self.d_k).transpose(0, 1)
+ ) # (t1,nh,dk)
+ B = torch.matmul(reshape_q,
+ pos_k.transpose(-2,
+ -1)) # pos_k: (t1,dk,t2)
+ B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0),
+ pos_k.size(1))
+ scores = A + B
+ else:
+ scores = A
+
+ if relative_attention_bias is not None:
+ scores = scores + relative_attention_bias
+
+ attn = masked_softmax(scores, mask) # (batch, head, time1, time2)
+
+ self.attn = attn
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn.to(v.dtype),
+ v) # (batch, head, time1, d_k)
+ if pos_v is not None:
+ reshape_attn = (p_attn.contiguous().view(
+ n_batch * self.h, pos_v.size(0),
+ pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2)
+
+ attn_v = (torch.matmul(reshape_attn, pos_v).transpose(
+ 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0),
+ self.d_k))
+ x = x + attn_v
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h_k * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+
+def validate_checkpointing_config(activation_checkpointing):
+ """validate activation checkpointing configuration"""
+ if isinstance(activation_checkpointing, str):
+ assert activation_checkpointing in (
+ "",
+ "checkpoint",
+ "offload",
+ ), "activation_checkpointing has to be a dict or a str in "\
+ "('', 'checkpoint', 'offload')."
+ elif isinstance(activation_checkpointing, dict):
+ assert activation_checkpointing.get("module", "transformer") in (
+ "transformer",
+ "attention",
+ ), "module in activation_checkpointing has to be in "\
+ "('transformer', 'attention')."
+ else:
+ raise ValueError("activation_checkpointing has to be a str"\
+ " or dict.")
+
+
+def embedding_checkpoint_wrapper(
+ activation_checkpointing: Union[str, Dict], ) -> Callable:
+ """return encoder embedding activation checkpoint wrapper"""
+ validate_checkpointing_config(activation_checkpointing)
+
+ if isinstance(activation_checkpointing, str):
+ if activation_checkpointing:
+ if activation_checkpointing == "offload":
+ return offload_wrapper
+ return partial(checkpoint_wrapper)
+ return lambda x: x
+
+ if isinstance(activation_checkpointing, dict):
+ enabled = activation_checkpointing.get("embed", False)
+ if enabled:
+ offloading = activation_checkpointing.get("offload", False)
+ if offloading:
+ return offload_wrapper
+ impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get(
+ "reentrant", False) else CheckpointImpl.NO_REENTRANT)
+ return partial(checkpoint_wrapper, checkpoint_impl=impl)
+ return lambda x: x
+ raise ValueError("Invalid activation_checkpointing config")
+
+
+def attn_checkpointing(activation_checkpointing: Union[str, Dict],
+ i) -> Union[str, Dict]:
+ """return activation checkpointing config for attention layer"""
+ if isinstance(activation_checkpointing, str):
+ return ""
+
+ if isinstance(activation_checkpointing, dict):
+ target_layer_cls = activation_checkpointing.get(
+ "module", "transformer")
+ checkpointing_interval = activation_checkpointing.get("interval", 1)
+ if target_layer_cls == "attention" and i % checkpointing_interval == 0:
+ return activation_checkpointing
+ return ""
+
+ raise ValueError("Invalid activation_checkpointing config")
+
+
+class MultiSequential(torch.nn.Sequential):
+ """Multi-input multi-output torch.nn.Sequential"""
+
+ @torch.jit.ignore
+ def forward(self, *args):
+ """Forward method implementation."""
+ for m in self:
+ args = m(*args)
+ return args
+
+
+def repeat(repeat_num, module_gen_fn):
+ """repeat module N times
+
+ :param int repeat_num: repeat time
+ :param function module_gen_fn: function to generate module
+ :return: repeated modules
+ :rtype: MultiSequential
+ """
+ return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)])
+
+
+def get_offset(input_layer: str, time_reduction: int):
+ """Get an offset. We will use the offset for determining #frames of a
+ subsampled feature.
+
+ Args:
+ input_layer (str): Type of an input layer
+ time_reduction (int): time reduction factor for downsampling a feature
+ Returns:
+ int: offset
+ """
+ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4:
+ return 3
+ if input_layer in ("conv2d", ) and time_reduction == 6:
+ return 1
+ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8:
+ return 7
+ return 0
+
+
+def unfold_tensor(xs_pad, max_seq_len):
+ """
+    For a given tensor with shape (N, T, D), if sequence length T is
+    longer than max_seq_len, this function unfolds it to
+    (NT', max_seq_len, D) where T' is T // max_seq_len.
+ Args:
+ xs_pad: N, T, D
+ """
+ _, _, D = xs_pad.shape
+ xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T
+ # N x D x 1 x T => N x (D x max_seq_len) x T'
+ xs_pad = F.unfold(
+ xs_pad[..., None, :],
+ kernel_size=(1, max_seq_len),
+ stride=(1, max_seq_len),
+ )
+ new_bsz, _, slen = xs_pad.shape
+ # N x D x max_seq_len x T'
+ xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen)
+ # N x T' x max_seq_len x D
+ xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous()
+ # NT' x max_seq_len x D
+ xs_pad = xs_pad.view(-1, max_seq_len, D)
+ return xs_pad
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 3a7fcdcf7b370..74160e2d9ee40 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -182,6 +182,7 @@
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
+ "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py
new file mode 100644
index 0000000000000..3a9597a845ff9
--- /dev/null
+++ b/vllm/model_executor/models/vision_siglip_navit.py
@@ -0,0 +1,1966 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Siglip model configuration"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn.init import _calculate_fan_in_and_fan_out
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+from transformers.modeling_outputs import (BaseModelOutput,
+ BaseModelOutputWithPooling)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (ModelOutput, add_start_docstrings,
+ add_start_docstrings_to_model_forward, logging,
+ replace_return_docstrings)
+
+from vllm.platforms import _Backend
+
+from .vision import get_vit_attn_backend
+
+logger = logging.get_logger(__name__)
+
+SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "google/siglip-base-patch16-224":
+ "https://huggingface.co/google/siglip-base-patch16-224/"\
+ "resolve/main/config.json",
+}
+
+
+class SiglipTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a
+ [`SiglipTextModel`]. It is used to instantiate a Siglip text encoder
+ according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar
+ configuration to that of the text encoder of the Siglip [google/
+ siglip-base-patch16-224](https://huggingface.co/google/siglip-base
+ -patch16-224) architecture.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used
+ to control the model outputs. Read the documentation from
+ [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Siglip text model. Defines the number of
+ different tokens that can be represented by the `inputs_ids`
+ passed when calling [`SiglipModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer
+ in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the
+ Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 64):
+ The maximum sequence length that this model might ever be used
+ with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to
+ `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the
+ encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ The id of the padding token in the vocabulary.
+ bos_token_id (`int`, *optional*, defaults to 49406):
+ The id of the beginning-of-sequence token in the vocabulary.
+ eos_token_id (`int`, *optional*, defaults to 49407):
+ The id of the end-of-sequence token in the vocabulary.
+ Example:
+ ```python
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224
+ style configuration
+ >>> configuration = SiglipTextConfig()
+ >>> # Initializing a SiglipTextModel (with random weights) from the
+ google/siglip-base-patch16-224 style configuration
+ >>> model = SiglipTextModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "siglip_text_model"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ max_position_embeddings=64,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
+ # See https://github.com/huggingface/transformers/pull/24773#
+ # issuecomment-1632287538
+ pad_token_id=1,
+ bos_token_id=49406,
+ eos_token_id=49407,
+ _flash_attn_2_enabled=True,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.attention_dropout = attention_dropout
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
+ os.PathLike],
+ **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(
+ pretrained_model_name_or_path, **kwargs)
+
+ # get the text config dict if we are loading from SiglipConfig
+ if config_dict.get("model_type") == "siglip":
+ config_dict = config_dict["text_config"]
+
+ if "model_type" in config_dict and hasattr(
+ cls,
+ "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ "You are using a model of type %s to instantiate a model of "
+ "type %s. This is not supported for all configurations of "
+ "models and can yield errors.", config_dict['model_type'],
+ cls.model_type)
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class SiglipVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a
+ [`SiglipVisionModel`]. It is used to instantiate a
+ Siglip vision encoder according to the specified arguments, defining the
+ model architecture. Instantiating a configuration with the defaults will
+ yield a similar configuration to that of the vision encoder of the Siglip
+ [google/siglip-base-patch16-224](https://huggingface.co/google/
+ siglip-base-patch16-224) architecture.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used
+ to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer
+ in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the
+ Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input images.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to
+ `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the
+ encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"`,
+ `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ Example:
+ ```python
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224
+ style configuration
+ >>> configuration = SiglipVisionConfig()
+ >>> # Initializing a SiglipVisionModel (with random weights) from the
+ google/siglip-base-patch16-224 style configuration
+ >>> model = SiglipVisionModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "siglip_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ num_channels=3,
+ image_size=224,
+ patch_size=16,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ _flash_attn_2_enabled=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
+ os.PathLike],
+ **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(
+ pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from SiglipConfig
+ if config_dict.get("model_type") == "siglip":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(
+ cls,
+ "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ "You are using a model of type %s to "
+ "instantiate a model of type %s. This is not"
+ " supported for all configurations of models and can yield"
+ " errors.", config_dict['model_type'], cls.model_type)
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class SiglipConfig(PretrainedConfig):
+ r"""
+ [`SiglipConfig`] is the configuration class to store the configuration of a
+ [`SiglipModel`]. It is used to instantiate a Siglip model according to the
+ specified arguments, defining the text model and vision model configs.
+ Instantiating a configuration with the defaults will yield a similar
+ configuration to that of the Siglip [google/siglip-base-patch16-224](
+ https://huggingface.co/google/siglip-base-patch16-224) architecture.
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to
+ control the model outputs. Read the documentation from
+ [`PretrainedConfig`] for more information.
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize
+ [`SiglipTextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize
+ [`SiglipVisionConfig`].
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+ Example:
+ ```python
+ >>> from transformers import SiglipConfig, SiglipModel
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224
+ style configuration
+ >>> configuration = SiglipConfig()
+ >>> # Initializing a SiglipModel (with random weights) from the
+ google/siglip-base-patch16-224 style configuration
+ >>> model = SiglipModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig
+ and a SiglipVisionConfig
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
+ >>> # Initializing a SiglipText and SiglipVision configuration
+ >>> config_text = SiglipTextConfig()
+ >>> config_vision = SiglipVisionConfig()
+ >>> config = SiglipConfig.from_text_vision_configs(config_text,
+ config_vision)
+ ```"""
+
+ model_type = "siglip"
+
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
+ super().__init__(**kwargs)
+
+ if text_config is None:
+ text_config = {}
+ logger.info(
+ "`text_config` is `None`. Initializing the `SiglipTextConfig`"
+ " with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. initializing the "
+ "`SiglipVisionConfig` with default values.")
+
+ self.text_config = SiglipTextConfig(**text_config)
+ self.vision_config = SiglipVisionConfig(**vision_config)
+
+ self.initializer_factor = 1.0
+
+ @classmethod
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig,
+ vision_config: SiglipVisionConfig, **kwargs):
+ r"""
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text
+ model configuration and siglip vision
+ model configuration.
+ Returns:
+ [`SiglipConfig`]: An instance of a configuration object
+ """
+
+ return cls(text_config=text_config.to_dict(),
+ vision_config=vision_config.to_dict(),
+ **kwargs)
+
+
+# coding=utf-8
+# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Siglip model."""
+
+_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
+
+SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "google/siglip-base-patch16-224",
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
+]
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
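+ # Given a padding mask of shape (batch, seq_len) with 1s at real tokens,
+ # return the flat indices of those tokens, the cumulative sequence lengths
+ # in the layout expected by flash_attn_varlen_func, and the longest
+ # sequence length. Illustrative sketch: for a mask [[1, 1, 0], [1, 1, 1]]
+ # this yields indices [0, 1, 3, 4, 5], cu_seqlens [0, 2, 5], max_seqlen 3.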
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official
+ # releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/
+ # truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std) # noqa
+ u = norm_cdf((b - mean) / std) # noqa
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ if tensor.dtype in [torch.float16, torch.bfloat16]:
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
+ og_dtype = tensor.dtype
+ tensor = tensor.to(torch.float32)
+ tensor.erfinv_()
+ tensor = tensor.to(og_dtype)
+ else:
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ if tensor.dtype == torch.float16:
+ # The `clamp_` op is not (yet?) defined in float16+cpu
+ tensor = tensor.to(torch.float32)
+ tensor.clamp_(min=a, max=b)
+ tensor = tensor.to(torch.float16)
+ else:
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(tensor: torch.Tensor,
+ mean: float = 0.0,
+ std: float = 1.0,
+ a: float = -2.0,
+ b: float = 2.0) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \\leq \\text{mean} \\leq b`.
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where
+ the bounds [a, b] are applied when sampling the normal distribution with
+ mean=0, std=1.0 and the result is subsequently scaled and shifted by the
+ mean and std args.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
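+ # Variance-scaling initializer in the style of JAX/Flax: the target
+ # variance is scale / fan (fan_in, fan_out, or their average), and values
+ # are drawn from a truncated-normal, normal, or uniform distribution with
+ # that variance. lecun_normal_ and default_flax_embed_init below are thin
+ # wrappers over this with fixed arguments.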
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with
+# CLIP->Siglip
+class SiglipVisionModelOutput(ModelOutput):
+ """
+ Base class for vision model's outputs that also contains image embeddings
+ of the pooling of the last hidden states.
+ Args:
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`
+ *optional* returned when model is initialized with
+ `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to
+ the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size,
+ sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the
+ model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when
+ `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings,
+ if the model has an embedding layer, + one for the output of each
+ layer) of shape `(batch_size, sequence_length, hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the
+ optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when
+ `output_attentions=True` is passed or when
+ `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape
+ `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights after the attention softmax, used to compute the
+ weighted average in the self-attention heads.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with
+# CLIP->Siglip
+class SiglipTextModelOutput(ModelOutput):
+ """
+ Base class for text model's outputs that also contains a pooling of the
+ last hidden states.
+ Args:
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`
+ *optional* returned when model is initialized with
+ `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to
+ the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size,
+ sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the
+ model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when
+ `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the
+ embeddings, if the model has an embedding layer, + one for the
+ output of each layer) of shape `(batch_size, sequence_length,
+ hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the
+ optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when
+ `output_attentions=True` is passed or when
+ `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape
+ `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights after the attention softmax, used to compute
+ the weighted average in the self-attention heads.
+ """
+
+ text_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPOutput with
+# CLIP->Siglip
+class SiglipOutput(ModelOutput):
+ """
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when
+ `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size,
+ text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and
+ `text_embeds`. This represents the image-text similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size,
+ image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and
+ `image_embeds`. This represents the text-image similarity scores.
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to
+ the pooled output of [`SiglipTextModel`].
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to
+ the pooled output of [`SiglipVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`SiglipTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`SiglipVisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: torch.FloatTensor = None
+ logits_per_text: torch.FloatTensor = None
+ text_embeds: torch.FloatTensor = None
+ image_embeds: torch.FloatTensor = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"
+ ] else getattr(self, k).to_tuple()
+ for k in self.keys())
+
+
+class SiglipVisionEmbeddings(nn.Module):
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions,
+ self.embed_dim)
+
+ def forward(self, pixel_values: torch.FloatTensor,
+ patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+ batch_size = pixel_values.size(0)
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, \
+ max_im_w // self.patch_size
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
+ 1 / self.num_patches_per_side)
+ position_ids = torch.full(
+ size=(
+ batch_size,
+ max_nb_patches_h * max_nb_patches_w,
+ ),
+ fill_value=0,
+ )
+
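+ # Assign position ids only to the valid (non-padded) patches of each
+ # image: the fractional (row, col) coordinates of those patches are
+ # bucketized onto the fixed num_patches_per_side grid, so inputs whose
+ # patch grid differs from the base resolution can still index the
+ # fixed-size position-embedding table.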
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = torch.linspace(0, 1 - 1 / nb_patches_h,
+ nb_patches_h)
+ fractional_coords_w = torch.linspace(0, 1 - 1 / nb_patches_w,
+ nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(fractional_coords_h,
+ boundaries,
+ right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w,
+ boundaries,
+ right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
+ bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with
+# CLIP->Siglip
+class SiglipTextEmbeddings(nn.Module):
+
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings,
+ embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and
+ # exported when serialized
+ self.register_buffer(
+ "position_ids",
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
+ persistent=False)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[
+ -1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class SiglipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`:"
+ f" {self.embed_dim} and `num_heads`: {self.num_heads}).")
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = torch.matmul(query_states, key_states.transpose(
+ 2, 3)) * self.scale
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len,
+ k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size "
+ f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}")
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(f"Attention mask should be of size "
+ f"{(batch_size, 1, q_len, k_v_seq_len)}, "
+ f"but is {attention_mask.size()}")
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights,
+ dim=-1,
+ dtype=torch.float32).to(
+ query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights,
+ p=self.dropout,
+ training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len,
+ self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size "
+ f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}")
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SiglipFlashAttention2(SiglipAttention):
+ """
+ Siglip flash attention module. This module inherits from
+ `SiglipAttention`, as the weights of the module stay untouched. The only
+ required change is in the forward pass, where it needs to correctly call
+ the public API of flash attention and deal with padding tokens in case the
+ input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_causal = False # Hack to make sure we don't use a causal mask
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(
+ kv_seq_len, self.layer_idx)
+
+ # TODO: These transpose are quite inefficient but Flash Attention
+ # requires the layout [batch_size, sequence_length, num_heads,
+ # head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training
+ # stability reasons therefore the input hidden states gets silently
+ # casted in float32. Hence, we need cast them back in the correct
+ # dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to
+ # not cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ "The input hidden states seems to be silently casted in "
+ "float32, this might be related to the fact you have upcasted "
+ "embedding or layer norm layers in float32. We will cast "
+ f"back the input in {target_dtype}.")
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate)
+
+ attn_output = attn_output.reshape(bsz, q_len,
+ self.embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+ def _flash_attention_forward(self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None):
+ """
+ Calls the forward method of Flash Attention - if the input hidden
+ states contain at least one padding token first unpad the input,
+ then computes the attention scores and pad the final attention
+ scores.
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size
+ `(batch_size, seq_len)` where 0 stands for the position
+ of padding tokens and 1 for the position of non-padding
+ tokens.
+ dropout (`float`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 /
+ sqrt(head_dim)
+ """
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import pad_input # noqa
+
+ # TODO: Remove the `query_length != 1` check once Flash Attention for
+ # RoCm is bumped to 2.1. For details, please see the comment in
+ # LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, \
+ max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask,
+ query_length)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
+ query_length)
+ else:
+ attn_output = flash_attn_func(query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal)
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
+ query_length):
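+ # Strip padding so that flash_attn_varlen_func only sees real tokens:
+ # keys/values are always unpadded with the full mask, queries reuse that
+ # unpadding when query_length equals the key length, are squeezed for
+ # single-token (query_length == 1) steps, and are otherwise unpadded with
+ # the trailing slice of the mask (which assumes left padding).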
+ from flash_attn.bert_padding import index_first_axis, unpad_input
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
+ attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
+ head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \
+ unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
+class SiglipMLP(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with
+# CLIP->Siglip
+class SiglipEncoderLayer(nn.Module):
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = (SiglipAttention(config) if
+ not getattr(config, "_flash_attn_2_enabled", False)
+ else SiglipFlashAttention2(config))
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where
+ padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all
+ attention layers. See `attentions` under returned tensors for
+ more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states, )
+
+ if output_attentions:
+ outputs += (attn_weights, )
+
+ return outputs
+
+
+class SiglipPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface
+ for downloading and loading pretrained models.
+ """
+
+ config_class = SiglipConfig
+ base_model_prefix = "siglip"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+
+ if isinstance(module, SiglipVisionEmbeddings):
+ width = (self.config.vision_config.hidden_size if isinstance(
+ self.config, SiglipConfig) else self.config.hidden_size)
+ nn.init.normal_(module.position_embedding.weight,
+ std=1 / np.sqrt(width))
+ elif isinstance(module, nn.Embedding):
+ default_flax_embed_init(module.weight)
+ elif isinstance(module, SiglipAttention):
+ nn.init.normal_(module.q_proj.weight)
+ nn.init.normal_(module.k_proj.weight)
+ nn.init.normal_(module.v_proj.weight)
+ nn.init.normal_(module.out_proj.weight)
+ nn.init.zeros_(module.q_proj.bias)
+ nn.init.zeros_(module.k_proj.bias)
+ nn.init.zeros_(module.v_proj.bias)
+ nn.init.zeros_(module.out_proj.bias)
+ elif isinstance(module, SiglipMLP):
+ nn.init.normal_(module.fc1.weight)
+ nn.init.normal_(module.fc2.weight)
+ nn.init.normal_(module.fc1.bias, std=1e-6)
+ nn.init.normal_(module.fc2.bias, std=1e-6)
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
+ nn.init.normal_(module.probe.data)
+ nn.init.normal_(module.attention.in_proj_weight.data)
+ nn.init.zeros_(module.attention.in_proj_bias.data)
+ elif isinstance(module, SiglipModel):
+ logit_scale_init = torch.tensor(0.0)
+ module.logit_scale.data.fill_(logit_scale_init)
+ module.logit_bias.data.zero_()
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+SIGLIP_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass
+ documentation for the generic methods the library implements for all
+ its models (such as downloading or saving, resizing the input embeddings,
+ pruning heads etc.)
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/
+ stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation
+ for all matter related to general usage and behavior.
+ Parameters:
+ config ([`SiglipConfig`]): Model configuration class with all the
+ parameters of the model.
+ Initializing with a config file does not load the weights
+ associated with the model, only the configuration. Check out
+ the [`~PreTrainedModel.from_pretrained`] method to load the
+ model weights.
+"""
+
+SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)
+ `):
+ Indices of input sequence tokens in the vocabulary. Padding will
+ be ignored by default should you provide it.
+ Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
+ for details. [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size,
+ sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask
+ values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size,
+ sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position
+ embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention
+ layers. See `attentions` under returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See
+ `hidden_states` under returned tensors for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a
+ plain tuple.
+"""
+
+SIGLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size,
+ num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you
+ provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`]
+ for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention
+ layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See
+ `hidden_states` under returned tensors for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a
+ plain tuple.
+"""
+
+SIGLIP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size,
+ sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding
+ will be ignored by default should you provide it.
+ Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
+ for details. [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`
+ , *optional*):
+ Mask to avoid performing attention on padding token indices. Mask
+ values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size,
+ sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position
+ embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size,
+ num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you
+ provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`]
+ for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention
+ layers. See `attentions` under returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See
+ `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a
+ plain tuple.
+"""
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with
+# CLIP->Siglip
+class SiglipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers`
+ self attention layers. Each layer is a [`SiglipEncoderLayer`].
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
+ ])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size,
+ sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to
+ directly pass an embedded representation.
+ This is useful if you want more control over how to convert
+ `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size,
+ sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices.
+ Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all
+ attention layers. See `attentions` under returned tensors for
+ more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See
+ `hidden_states` under returned tensors for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a
+ plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions \
+ is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else \
+ self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states, )
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1], )
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states, )
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, encoder_states, all_attentions]
+ if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions)
+
+
+class SiglipTextTransformer(nn.Module):
+
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = SiglipTextEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps)
+
+ self.head = nn.Linear(embed_dim, embed_dim)
+
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
+ config_class=SiglipTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+ """
+ output_attentions = output_attentions if output_attentions \
+ is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states \
+ is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else \
+ self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids,
+ position_ids=position_ids)
+
+ # note: SigLIP's text model does not use a causal mask, unlike the
+ # original CLIP model.
+ # expand attention_mask
+ if attention_mask is not None:
+ # [batch_size, seq_len] ->
+ # [batch_size, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
+ pooled_output = last_hidden_state[:, -1, :]
+ pooled_output = self.head(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """The text model from SigLIP without any head or projection on top.""",
+ SIGLIP_START_DOCSTRING,
+)
+class SiglipTextModel(SiglipPreTrainedModel):
+ config_class = SiglipTextConfig
+
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
+
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__(config)
+ self.text_model = SiglipTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
+ config_class=SiglipTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+ Examples:
+ ```python
+ >>> from transformers import AutoTokenizer, SiglipTextModel
+ >>> model = SiglipTextModel.
+ from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.
+ from_pretrained("google/siglip-base-patch16-224")
+ >>> # important: make sure to set padding="max_length"
+ as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"],
+ padding="max_length", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token)
+ states
+ ```"""
+ return_dict = return_dict if return_dict is not None else \
+ self.config.use_return_dict
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class SiglipVisionTransformer(nn.Module):
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SiglipVisionEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps)
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
+
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
+ config_class=SiglipVisionConfig)
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+ """
+ output_attentions = output_attentions if output_attentions is not None\
+ else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None \
+ else self.config.use_return_dict
+
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_attention_mask = torch.ones(
+ size=(
+ batch_size,
+ pixel_values.size(2) // self.config.patch_size,
+ pixel_values.size(3) // self.config.patch_size,
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask)
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending
+ # to the whole sequence), avoiding passing the attention_mask, which
+ # is equivalent to attending to the full sequence
+ if not torch.any(~patch_attention_mask):
+ attention_mask = None
+ else:
+ attention_mask = (_prepare_4d_attention_mask(
+ patch_attention_mask, hidden_states.dtype)
+ if not self.config._flash_attn_2_enabled else
+ patch_attention_mask)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooled_output = self.head(
+ hidden_state=last_hidden_state,
+ attention_mask=patch_attention_mask,
+ )
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(
+ config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size,
+ eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+
+ def forward(self, hidden_state, attention_mask):
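+ # A single learned "probe" query attends over all patch tokens (padded
+ # patches are excluded via key_padding_mask), so position 0 of the
+ # attention output is a weighted pooling of the image features.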
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(query=probe,
+ key=hidden_state,
+ value=hidden_state,
+ key_padding_mask=~attention_mask)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+@add_start_docstrings(
+ """The vision model from SigLIP without any head or projection on top.""",
+ SIGLIP_START_DOCSTRING,
+)
+class SiglipVisionModel(SiglipPreTrainedModel):
+ config_class = SiglipVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = SiglipVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
+ config_class=SiglipVisionConfig)
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, SiglipVisionModel
+ >>> model = SiglipVisionModel.from_pretrained(
+ "google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained(
+ "google/siglip-base-patch16-224")
+ >>> url =
+ "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+ return_dict = return_dict if return_dict is not None \
+ else self.config.use_return_dict
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+@add_start_docstrings(SIGLIP_START_DOCSTRING)
+class SiglipModel(SiglipPreTrainedModel):
+ config_class = SiglipConfig
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, SiglipTextConfig):
+ raise ValueError("config.text_config is expected to be of type "
+ f"SiglipTextConfig but is of type"
+ f" {type(config.text_config)}.")
+
+ if not isinstance(config.vision_config, SiglipVisionConfig):
+ raise ValueError("config.vision_config is expected to be of type "
+ "SiglipVisionConfig but is of type"
+ f" {type(config.vision_config)}.")
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.text_model = SiglipTextTransformer(text_config)
+ self.vision_model = SiglipVisionTransformer(vision_config)
+
+ self.logit_scale = nn.Parameter(torch.randn(1))
+ self.logit_bias = nn.Parameter(torch.randn(1))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size,
+ output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output
+ of [`SiglipTextModel`].
+ Examples:
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModel
+ >>> import torch
+ >>> model = AutoModel.from_pretrained(
+ "google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained(
+ "google/siglip-base-patch16-224")
+ >>> # important: make sure to set padding="max_length" as that's
+ how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"],
+ padding="max_length", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead
+ # of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None\
+ else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None \
+ else self.config.use_return_dict
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = text_outputs[1]
+
+ return pooled_output
+
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size,
+ output_dim)`): The image embeddings obtained by applying the
+ projection layer to the pooled output of [`SiglipVisionModel`].
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained(
+ "google/siglip-base-patch16-224")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = processor(images=image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use SiglipModel's config for some fields (if specified) instead
+ # of those of vision & text components.
+ output_attentions = (output_attentions if output_attentions is not None
+ else self.config.output_attentions)
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states)
+ return_dict = (return_dict if return_dict is not None
+ else self.config.use_return_dict)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = vision_outputs[1]
+
+ return pooled_output
+
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=SiglipOutput,
+ config_class=SiglipConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SiglipOutput]:
+ r"""
+ Returns:
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained(
+ ...     "google/siglip-base-patch16-224")
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+ >>> # important: we pass `padding=max_length` since the model was
+ >>> # trained with this
+ >>> inputs = processor(text=texts, images=image,
+ ...     padding="max_length", return_tensors="pt")
+ >>> with torch.no_grad():
+ ...     outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image
+ >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+ 31.9% that image 0 is 'a photo of 2 cats'
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of
+ # those of vision & text components.
+ output_attentions = (output_attentions if output_attentions is not None
+ else self.config.output_attentions)
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states)
+ return_dict = (return_dict if return_dict is not None
+ else self.config.use_return_dict)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ image_embeds = vision_outputs[1]
+ text_embeds = text_outputs[1]
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(
+ p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logits_per_text = (torch.matmul(text_embeds, image_embeds.t()) *
+ self.logit_scale.exp() + self.logit_bias)
+ logits_per_image = logits_per_text.t()
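+ # Note: unlike CLIP, SigLIP scores each text/image pair independently;
+ # probabilities are obtained with torch.sigmoid on these logits (see the
+ # docstring example above), not with a softmax over the batch.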
+
+ loss = None
+ if return_loss:
+ raise NotImplementedError("SigLIP loss to be implemented")
+
+ if not return_dict:
+ output = (logits_per_image, logits_per_text, text_embeds,
+ image_embeds, text_outputs, vision_outputs)
+ return ((loss, ) + output) if loss is not None else output
+
+ return SiglipOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
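+ """Build the standalone SigLIP vision encoder used for Phi-4-multimodal.
+
+ Returns the `vision_model` submodule of a `SiglipVisionModel` constructed
+ from the fixed configuration below (27 layers, hidden size 1152, patch
+ size 14, 448x448 input). Flash attention is requested by default but is
+ disabled when the detected ViT attention backend does not support it.
+ """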
+ siglip_vision_config = {
+ "hidden_size": 1152,
+ "image_size": 448,
+ "intermediate_size": 4304,
+ "model_type": "siglip_vision_model",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 27,
+ "patch_size": 14,
+ }
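+ # The dimensions above follow the SigLIP-SO400M backbone
+ # (cf. google/siglip-so400m-patch14-384), here used at 448x448 resolution.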
+
+ # Detect attention implementation.
+ attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+ if attn_backend != _Backend.FLASH_ATTN:
+ _flash_attn_2_enabled = False
+
+ model_config = SiglipVisionConfig(
+ **siglip_vision_config,
+ _flash_attn_2_enabled=_flash_attn_2_enabled,
+ **kwargs)
+
+ vision_model = SiglipVisionModel(model_config).vision_model
+
+ return vision_model
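+
+
+# Illustrative sketch (not exercised anywhere in this file): how a caller
+# might obtain per-patch features from the vision tower built above. The
+# 448x448 resolution matches the `image_size` configured in
+# get_siglip_vision_model(); dtype/device placement and any optional
+# attention-mask arguments of the vendored transformer are omitted.
+#
+#   vision_model = get_siglip_vision_model()
+#   pixel_values = torch.randn(1, 3, 448, 448)
+#   with torch.no_grad():
+#       patch_features = vision_model(pixel_values).last_hidden_state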