diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 409a4d1210bc3..fc363585b0e7e 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -410,7 +410,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `Phi3ForCausalLM` * Phi-4, Phi-3 - * `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. + * `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. * ✅︎ * ✅︎ - * `Phi3SmallForCausalLM` @@ -856,6 +856,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Phi4MMForCausalLM` + * Phi-4-multimodal + * T + I+ / T + A+ / I+ + A+ + * `microsoft/Phi-4-multimodal-instruct`, etc. + * ✅︎ + * + * - * `PixtralForConditionalGeneration` * Pixtral * T + I+ diff --git a/requirements-common.txt b/requirements-common.txt index da6bc2af68cf2..1ca64c9948974 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -37,3 +37,4 @@ depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md +scipy # Required for phi-4-multimodal-instruct \ No newline at end of file diff --git a/tests/models/registry.py b/tests/models/registry.py index 97db33b46fade..3c3247eaf3e99 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -272,6 +272,8 @@ def check_available_online( extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True), + "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", + trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 tokenizer_mode="mistral"), "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", diff --git a/vllm/config.py b/vllm/config.py index f87d2d6e82cf8..3f1bff4981294 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2284,9 +2284,9 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - # Setting the maximum rank to 256 should be able to satisfy the vast + # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. 
-        possible_max_ranks = (8, 16, 32, 64, 128, 256)
+        possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
         possible_lora_extra_vocab_size = (0, 256, 512)
         if self.max_lora_rank not in possible_max_ranks:
             raise ValueError(
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index b05842dd27d3b..8f906cf1d80b6 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -395,6 +395,8 @@ def _placeholder_str(self, modality: ModalityStr,
             if model_type == "phi3_v":
                 # Workaround since this token is not defined in the tokenizer
                 return f"<|image_{current_count}|>"
+            if model_type == "phi4mm":
+                return "<|endoftext10|>"  # 200010 (see vocab.json in hf model)
             if model_type in ("minicpmo", "minicpmv"):
                 return "(./)"
             if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
@@ -424,6 +426,8 @@ def _placeholder_str(self, modality: ModalityStr,
         elif modality == "audio":
             if model_type == "ultravox":
                 return "<|audio|>"
+            if model_type == "phi4mm":
+                return "<|endoftext11|>"  # 200011 (see vocab.json in hf model)
             if model_type == "qwen2_audio":
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
new file mode 100644
index 0000000000000..27ae9bcca2e43
--- /dev/null
+++ b/vllm/model_executor/models/phi4mm.py
@@ -0,0 +1,1803 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+import re
+from functools import lru_cache
+from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
+                    TypedDict, Union)
+
+import numpy as np
+import scipy.signal
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import PretrainedConfig
+from transformers.utils import logging
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext)
+from vllm.inputs.data import TokenInputs, token_inputs
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
+from vllm.model_executor.models.llama import LlamaModel
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
+from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+
+from .interfaces import SupportsLoRA, SupportsMultiModal
+from .phi4mm_audio import AudioEmbedding
+from .utils import maybe_prefix
+from .vision_siglip_navit import get_siglip_vision_model
+
+# <|endoftext10|> (see vocab.json in hf model)
+_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
+# <|endoftext11|>
+_AUDIO_PLACEHOLDER_TOKEN_ID = 200011
+
+_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
+DUMMY_SAMPLING_FREQUENCY = 16_000  # Hz (i.e. 16 kHz)
+
+DYNAMIC_HD = 16
+AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
+IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"
+
+SIGLIP_NAME = "siglip-so400m-patch14-448"
+VISION_ENCODER_TO_PROCESSING_CONFIG = {
+    'siglip-so400m-patch14-448': {
+        'dynamic_hd': 16,
+        'vit_image_size': 448,
+        'vit_patch_size': 14,
+        'token_compression_factor': 2,
+    },
+}
+logger =
logging.get_logger(__name__) +# This is a workaround to prevent text (user input) + audio + image +# from being used in the same prompt. +# It includes token ids for "/n" and tokens in added_tokens_decoder +# from the tokenizer_confg.json file. +NON_USER_INPUT_TOKENS = { + 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, + 200023, 200024, 200025, 200026, 200027, 200028 +} + + +def get_max_dummy_image(ctx: InputContext): + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + + max_side = vit_image_size * dynamic_hd_size + dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) + return dummy_image + + +# image token length +def get_max_phi4mm_image_tokens(ctx: InputContext): + dummy_image = get_max_dummy_image(ctx) + + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, + vit_image_size, + vit_patch_size, + token_compression_factor) + return image_num_tokens + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, + image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def _find_target_aspect_ratio(image, image_size, max_num, min_num): + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + logger.debug("target_aspect_ratio: %s", target_aspect_ratio) + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width + + +def _get_padding_size(image, target_height, target_width): + orig_width, orig_height = image.size + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + + if ratio_width < ratio_height: + 
padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + return padding_height, padding_width + + +def dynamic_preprocess(image, + min_num=1, + max_num=12, + image_size=384, + mask_size=27): + target_aspect_ratio, target_height, target_width =\ + _find_target_aspect_ratio( + image, image_size, max_num, min_num) + padding_height, padding_width = _get_padding_size(image, target_height, + target_width) + + # Calculate the ratio + orig_width, orig_height = image.size + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + else: + new_size = (int(orig_width * ratio_height), target_height) + + attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), + int(mask_size * target_aspect_ratio[0]))) + if padding_width >= 14: + attention_mask[:, -math.floor(padding_width / 14):] = 0 + if padding_height >= 14: + attention_mask[-math.floor(padding_height / 14):, :] = 0 + assert attention_mask.sum( + ) > 0, f'attention mask is empty {attention_mask}' + + if min(new_size[1], target_height) < 10 or min(new_size[0], + target_width) < 10: + raise ValueError(f'the aspect ratio is very extreme {new_size}') + + image = T.functional.resize( + image, + [new_size[1], new_size[0]], + ) + + resized_img = T.functional.pad(image, + [0, 0, padding_width, padding_height], + fill=[255, 255, 255]) + + return resized_img, attention_mask + + +def pad_to_max_num_crops(images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if max_crops > B: + pad = torch.zeros(max_crops - B, + 3, + H, + W, + dtype=images.dtype, + device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + +def pad_mask_to_max_num_crops(masks, max_crops=5): + B, H, W = masks.shape + if max_crops > B: + pad = torch.ones(max_crops - B, + H, + W, + dtype=masks.dtype, + device=masks.device) + masks = torch.cat([masks, pad], dim=0) + return masks + + +def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): + + # Basic settings. 
+ img_processor = T.Compose([ + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + # Dynamic HD + base_resolution = vit_resolution + images = [image.convert('RGB') for image in images] + # cover 384 and 448 resolution + mask_resolution = base_resolution // vit_patch_size + elems, image_attention_masks = [], [] + for im in images: + elem, attention_mask = dynamic_preprocess(im, + max_num=dynamic_hd_size, + image_size=base_resolution, + mask_size=mask_resolution) + elems.append(elem) + image_attention_masks.append(attention_mask) + hd_images = [img_processor(im) for im in elems] + global_image = [ + torch.nn.functional.interpolate( + im.unsqueeze(0).float(), + size=(base_resolution, base_resolution), + mode='bicubic', + ).to(im.dtype) for im in hd_images + ] + shapes = [[im.size(1), im.size(2)] for im in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] + for mask in image_attention_masks] + global_attention_mask = [ + torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images + ] + hd_images_reshape = [ + im.reshape(1, 3, h // base_resolution, base_resolution, + w // base_resolution, base_resolution).permute( + 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, + base_resolution).contiguous() + for im, (h, w) in zip(hd_images, shapes) + ] + attention_masks_reshape = [ + mask.reshape(1, h // mask_resolution, mask_resolution, + w // mask_resolution, mask_resolution).permute( + 0, 1, 3, 2, 4).reshape(-1, mask_resolution, + mask_resolution).contiguous() + for mask, (h, w) in zip(image_attention_masks, mask_shapes) + ] + # NOTE token compression is hard coded here, and odd numbers seems to fail + downsample_attention_masks = [ + mask[:, 0::2, + 0::2].reshape(1, h // mask_resolution, w // mask_resolution, + mask_resolution // 2 + mask_resolution % 2, + mask_resolution // 2 + mask_resolution % 2).permute( + 0, 1, 3, 2, 4) + for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) + ] + downsample_attention_masks = [ + mask.reshape(mask.size(1) * mask.size(2), + mask.size(3) * mask.size(4)) + for mask in downsample_attention_masks + ] + # NOTE hard coded number of tokens + num_img_tokens = [ + 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 + for mask in downsample_attention_masks + ] + + hd_images_reshape = [ + torch.cat([_global_image] + [_im], dim=0) + for _global_image, _im in zip(global_image, hd_images_reshape) + ] + hd_masks_reshape = [ + torch.cat([_global_mask] + [_mask], + dim=0) for _global_mask, _mask in zip( + global_attention_mask, attention_masks_reshape) + ] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = [ + pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape + ] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [ + pad_mask_to_max_num_crops(mask, max_crops) \ + for mask in hd_masks_reshape + ] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "pixel_values": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + return data + + +class Phi4MMImageEncoder(nn.Module): + """Image embedding.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: 
str = "", + model_dir: str = "") -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): + embd_drop = config.embd_pdrop if hasattr( + config, 'embd_pdrop') else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + # layer_idx to output the img features + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + self.img_processor = get_siglip_vision_model( + _flash_attn_2_enabled=True) + + pe_weight = self.img_processor.embeddings.position_embedding.weight + L, D = pe_weight.size() + H = int(math.sqrt(L)) + assert H**2 == L, f'position embedding size {L} is not square' + if H % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + H += 1 + image_dim_out = D + # ((448/14)//2)**2 + self.num_img_tokens = (H // 2)**2 + self.base_feat_height_target = H + + self.image_dim_out = image_dim_out + self.img_sizes = None + self.image_attention_mask = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = True + self.with_learnable_separator = True + self.hd_transform_order = "sub_glb" + self.freeze_img_processor = False + self.crop_size = 448 + + # image token compression + self.image_token_compression_cls = 'avg_pool_2d' + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.base_feat_height_reduction = 1 + self.base_feat_height_target = self.base_feat_height_target // 2 + + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform == self.with_learnable_separator, \ + 'use_hd_transform and with_learnable_separator should have same value' + assert self.use_hd_transform, \ + 'learnable separator is only for hd transform' + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter( + torch.zeros([ + 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 + ])) + self.sub_GN = nn.Parameter( + torch.zeros([ + 1, 1, 1, + self.image_dim_out * self.base_feat_height_reduction**2 + ])) + + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out * self.base_feat_height_reduction**2, + dim_projection) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + + self.vocab_size = config.vocab_size + self.img_features = None + + self.use_out_place_operations = False + + def get_img_features(self, + img_embeds: torch.FloatTensor, + attention_mask=None) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor( + img_embeds, + output_hidden_states=True, + patch_attention_mask=attention_mask) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature + + use_token_compression = self.image_token_compression is not None + use_padding = getattr(self, 'img_processor_padding', + None) is not None + if use_token_compression or use_padding: + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, + patch_feature.size(-1)) + # convert to NCHW + patch_feature = 
patch_feature.permute(0, 3, 1, 2) + + if use_padding: + patch_feature = self.img_processor_padding(patch_feature) + if use_token_compression: + patch_feature = self.image_token_compression(patch_feature) + + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view( + -1, + patch_feature.size(1) * patch_feature.size(2), + patch_feature.size(-1)) + + return patch_feature + + raise NotImplementedError + + def forward(self, pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + image_attention_mask: torch.Tensor) -> torch.FloatTensor: + """ + process image and return vision embeddings. + + pixel_values: (num_images, num_crops, c, h, w) + image_sizes: [[h1, w1], [h2, w2]] + image_attention_mask: num_images x num_crops x 32 x 32 + output: (num_images, num_img_tokens, hidden_size) + """ + + # eg + # pixel_values: torch.Size([1, 7, 3, 448, 448]) + # image_sizes: tensor([[ 896, 1344]], device='cuda:0') + # output: torch.Size([1, 1841, 3072]) + + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + img_sizes = image_sizes + num_images, num_crops, c, h, w = pixel_values.shape + bs = num_images + pixel_values = pixel_values.flatten(0, 1) + + img_features = self.get_img_features( + pixel_values, + image_attention_mask.type(torch.BoolTensor).flatten( + 0, 1).to(target_device)) + + base_feat_height_target = self.base_feat_height_target + base_resolution = self.crop_size + base_feat_height_reduction = self.base_feat_height_reduction + + base_feat_height = base_feat_width = int(np.sqrt( + img_features.shape[1])) + assert base_feat_height == base_feat_height_target \ + and base_feat_width == base_feat_height_target, \ + f'base_feat_height: {base_feat_height},"\ + f" base_feat_width: {base_feat_width}, "\ + f"expect {base_feat_height_target} features for hd transform' + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view(bs, -1, + base_feat_height * base_feat_width, + self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // base_resolution + w = w // base_resolution + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape(1, H, H, C).reshape( + 1, H // base_feat_height_reduction, base_feat_height_reduction, + H // base_feat_height_reduction, base_feat_height_reduction, + C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( + 1, H // base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * + C).contiguous() + temp_glb_GN = self.sub_GN.repeat(1, + H // base_feat_height_reduction, + 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( + 1, -1, + base_feat_height_reduction * base_feat_height_reduction * C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) -> + # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = sub_img.reshape(B_, 
H, H, C).reshape( + B_, H // base_feat_height_reduction, + base_feat_height_reduction, H // base_feat_height_reduction, + base_feat_height_reduction, + C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( + B_, -1, base_feat_height_reduction * + base_feat_height_reduction * C).contiguous() + sub_img = sub_img.reshape( + 1, h, w, base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + -1).permute(0, 1, 3, 2, 4, 5).reshape( + 1, h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * + C) + + if image_attention_mask is not None and len( + image_attention_mask) > 0: + reshaped_image_attention_mask = image_attention_mask[ + _bs, 1:B_ + 1, 0::2, 0::2].reshape( + 1, h, w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction).permute( + 0, 1, 3, 2, 4).reshape( + 1, h * base_feat_height // + base_feat_height_reduction, w * + base_feat_width // base_feat_height_reduction) + useful_height = int( + reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int( + reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) + temp_len = int( + image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item( + )) + (useful_height + + 1) + base_feat_height // base_feat_height_reduction + else: + temp_sub_GN = self.sub_GN.repeat( + 1, h * base_feat_height // base_feat_height_reduction, 1, + 1) + temp_len = int((h * w + 1) * self.num_img_tokens + 1 + + (h + 1) * base_feat_height // + base_feat_height_reduction) + + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( + 1, -1, + base_feat_height_reduction * base_feat_height_reduction * C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order}, "\ + "not implemented') + + #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert temp_len == output_imgs[-1].shape[ + 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ + "{output_imgs[-1].shape[1]}' + + output_len.append(temp_len) + + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) + img_set_tensor.append(img_feature_proj) + + return img_set_tensor + + +class Phi4MMAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Tuple[NestedTensors] + """Shape: `((batch_size, num_audios, 80, M), )""" + + +class Phi4MMAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] + + +def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): + """Create a Mel filter-bank the same as SpeechLib FbankFC. + + Args: + sample_rate (int): Sample rate in Hz. number > 0 [scalar] + n_fft (int): FFT size. int > 0 [scalar] + n_mel (int): Mel filter size. int > 0 [scalar] + fmin (float): lowest frequency (in Hz). If None use 0.0. + float >= 0 [scalar] + fmax: highest frequency (in Hz). 
If None use sample_rate / 2. + float >= 0 [scalar] + + Returns + out (numpy.ndarray): Mel transform matrix + [shape=(n_mels, 1 + n_fft/2)] + """ + + bank_width = int(n_fft // 2 + 1) + if fmax is None: + fmax = sample_rate / 2 + if fmin is None: + fmin = 0 + assert fmin >= 0, "fmin cannot be negative" + assert (fmin < fmax <= + sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" + + def mel(f): + return 1127.0 * np.log(1.0 + f / 700.0) + + def bin2mel(fft_bin): + return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) + + def f2bin(f): + return int((f * n_fft / sample_rate) + 0.5) + + # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] + klo = f2bin(fmin) + 1 + khi = f2bin(fmax) + + khi = max(khi, klo) + + # Spec 2: SpeechLib uses triangles in Mel space + mlo = mel(fmin) + mhi = mel(fmax) + m_centers = np.linspace(mlo, mhi, n_mels + 2) + ms = (mhi - mlo) / (n_mels + 1) + + matrix = np.zeros((n_mels, bank_width), dtype=np.float32) + for m in range(0, n_mels): + left = m_centers[m] + center = m_centers[m + 1] + right = m_centers[m + 2] + for fft_bin in range(klo, khi): + mbin = bin2mel(fft_bin) + if left < mbin < right: + matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms + + return matrix + + +class LogFbankProcessor: + + def __init__(self): + + self._eightk_method = "fillzero" + self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T + + self._hamming400 = np.hamming(400) # for 16k audio + self._hamming200 = np.hamming(200) # for 8k audio + + def extract_spectrogram(self, wav, fs): + """Extract spectrogram features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + if wav.ndim > 1: + wav = np.squeeze(wav) + + # by default, we extract the mean if stereo + if len(wav.shape) == 2: + wav = wav.mean(1) + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 16000) + fs = 16000 + elif 8000 < fs < 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 8000) + fs = 8000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + if fs == 8000: + if self._eightk_method == "resample": + # Input audio is 8 kHz. Convert to 16 kHz before feature + # extraction + wav = scipy.signal.resample_poly(wav, 2, 1) + fs = 16000 + # Do nothing here for fillzero method + elif fs != 16000: + # Input audio is not a supported sample rate. 
+ raise RuntimeError( + f"Input data using an unsupported sample rate: {fs}") + + preemphasis = 0.97 + + if fs == 8000: + n_fft = 256 + win_length = 200 + hop_length = 80 + fft_window = self._hamming200 + elif fs == 16000: + n_fft = 512 + win_length = 400 + hop_length = 160 + fft_window = self._hamming400 + + # Spec 1: SpeechLib cut remaining sample insufficient for a hop + n_batch = (wav.shape[0] - win_length) // hop_length + 1 + # Here we don't use stride_tricks since the input array may not satisfy + # memory layout requirement and we need writeable output + # Here we only use list of views before copy to destination + # so it is more efficient than broadcasting + y_frames = np.array( + [ + wav[_stride:_stride + win_length] + for _stride in range(0, hop_length * n_batch, hop_length) + ], + dtype=np.float32, + ) + + # Spec 2: SpeechLib applies preemphasis within each batch + y_frames_prev = np.roll(y_frames, 1, axis=1) + y_frames_prev[:, 0] = y_frames_prev[:, 1] + y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + + S = np.fft.rfft(fft_window * y_frames, n=n_fft, + axis=1).astype(np.complex64) + + if fs == 8000: + # Need to pad the output to look like 16 kHz data but with zeros in + # the 4 to 8 kHz bins. + frames, bins = S.shape + padarray = np.zeros((frames, bins)) + S = np.concatenate((S[:, 0:-1], padarray), + axis=1) # Nyquist bin gets set to zero + + spec = np.abs(S).astype(np.float32) + return spec + + def extract_features(self, wav, fs): + """Extract log filterbank features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. 
+ """ + spec = self.extract_spectrogram(wav, fs) + spec_power = spec**2 + + fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) + log_fbank = np.log(fbank_power).astype(np.float32) + + return log_fbank + + +@lru_cache +def audio_feature_extractor() -> LogFbankProcessor: + # Creates an instance of the audio processor, needed to extract the + # the audio features from the sound file + # LRU cache ensures that we only make one copy + return LogFbankProcessor() + + +def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, + vit_patch_size, token_compression_factor): + """ + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels + + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, \ + "vit_image_size must be divisible by vit_patch_size" + assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ + "vit_image_size // vit_patch_size must be divisible by "\ + "token_compression_factor" + + target_aspect_ratio, target_height, target_width = ( + _find_target_aspect_ratio(image, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[ + 0] * vit_image_size == target_width, \ + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + assert target_aspect_ratio[ + 1] * vit_image_size == target_height, \ + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size(image, target_height, + target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 + num_sep_tokens = 1 + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + + num_hd_newline_tokens + num_global_image_newline_tokens) + + +def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: + """ + Compute the output size of the `extract_features` 
method. + + Args: + wav_length (int): Length of the input waveform in samples. + fs (int): Sampling rate of the waveform, either 16000 or 8000. + + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). + """ + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav_length //= fs // 16000 + fs = 16000 + elif 8000 <= fs < 16000: + # We'll resample to 16K from 8K + wav_length *= 2 + fs = 16000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + mel_bins = 80 # Number of mel filterbank bins + + # Calculate number of frames (T) + T = (wav_length - win_length) // hop_length + 1 + if T < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) and mel bins (D) + return T, mel_bins + + +def _get_audio_embed_sizes(audios, ctx: InputContext): + """ + Get the audio embedding sizes for each audio file. + + Args: + audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of + waveform and sample rate. + ctx (InputContext): Input context. + + Returns: + List[int]: List of audio embedding sizes. + """ + audio_embed_sizes = [] + for audio in audios: + audio_data, sf = audio + audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) + audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), + audio_frames) + audio_embed_sizes.append(audio_embed_size) + return audio_embed_sizes + + +def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): + """ + The following will search for `<|audio_{idx}|>` tokens and + return a mapping of audio placeholder tokens to audio placeholder token ids + based on the size of the audio embeddings. + + Args: + audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of + waveform and sample rate. + ctx (InputContext): Input context. + prompt_str (str): The prompt string. + + Returns: + Dict[str, List[int]]: Mapping of audio placeholder tokens to audio + placeholder token ids. + + """ + if len(audios) == 0: + return {} + + audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) + audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) + audio_ids = [int(audio_id) for audio_id in audio_ids] + assert len(audio_ids) == len( + audio_embed_sizes + ), "Number of audio tokens and audio features do not match" + assert tuple(audio_ids) == tuple(range(1, + len(audio_ids) + + 1)), "Audio ids are not in order!" 
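+    # For illustration (hypothetical sizes): if the prompt contains
+    # "<|audio_1|>" and "<|audio_2|>" and audio_embed_sizes == [5, 7], the
+    # mapping built below is
+    #   {"<|audio_1|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * 5,
+    #    "<|audio_2|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * 7}
+    # so each audio tag can later be expanded into the right number of
+    # placeholder token ids.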
+ audio_id_to_input_ids = { + f"<|audio_{audio_id}|>": + [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) + } + + return audio_id_to_input_ids + + +def _count_image_tokens(images, ctx: InputContext): + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + image_token_counts = [ + _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, + vit_patch_size, token_compression_factor) + for image in images + ] + return image_token_counts + + +def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): + if len(images) == 0: + return {} + + image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) + image_ids = [int(image_id) for image_id in image_ids] + assert len(image_ids) == len( + set(image_ids)), "Duplicate image tokens in prompt" + assert len(images) == len( + image_ids), "Number of images and image tokens in prompt do not match" + + # NOTE the following assertion is not strictly necessary + assert tuple(image_ids) == tuple(range(1, + len(image_ids) + + 1)), "Image ids are not in order" + + image_token_counts = _count_image_tokens(images, ctx) + image_id_to_input_ids = { + f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens + for image_id, num_tokens in zip(image_ids, image_token_counts) + } + return image_id_to_input_ids + + +def input_processor_for_phi4mm(ctx: InputContext, + inputs: DecoderOnlyInputs) -> TokenInputs: + """ + Implements the input processor, which transforms the input prompt ids + to include the audio placeholder token. This will become the `input_ids` + in `forward` for the model. + + Args: + ctx (InputContext): Input context. + inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) + to process. 
+ + Returns: + TokenInputs: Processed inputs + """ + multi_modal_data = inputs.get("multi_modal_data") + if (multi_modal_data is None or + ("audio" not in multi_modal_data and "image" not in multi_modal_data)): + # pure text input, so no need to do pre-processing + return inputs + + prompt_str = inputs.get("prompt") + prompt_token_ids = inputs.get("prompt_token_ids") + # for offline_inference, we will get str input and we parse MM special + # tokens from it + # (ignore prompt_token_ids) + # for OAI server, we will get prompt_token_ids, where MM special tokens + # are already parsed + + if 'audio' in multi_modal_data: + audios = multi_modal_data["audio"] + + if not isinstance(audios, list): + audios = [audios] + if prompt_str is not None: + audio_id_to_input_ids = _get_audio_id_to_input_ids( + audios, ctx, prompt_str=prompt_str) + audio_embed_sizes = [] + elif prompt_token_ids is not None: + audio_id_to_input_ids = {} + audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) + else: + audio_id_to_input_ids = {} + audio_embed_sizes = [] + + if 'image' in multi_modal_data: + # PIL Image or list of PIL Images + images = multi_modal_data["image"] + if not isinstance(images, list): + images = [images] + if prompt_str is not None: + image_id_to_input_ids = _get_image_id_to_input_ids( + images, prompt_str, ctx) + image_token_counts = [] + elif prompt_token_ids is not None: + image_id_to_input_ids = {} + image_token_counts = _count_image_tokens(images, ctx) + else: + image_id_to_input_ids = {} + image_token_counts = [] + + # Handle the case where the prompt is a string and we need to manually + # tokenize it. + # In this case, the `audio_id_to_input_ids` dict will be mapping from + # an audio placeholder + # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the + # given audio length. 
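+    # For illustration (hypothetical prompt): given
+    #   "<|user|><|image_1|>Describe it.<|end|><|assistant|>"
+    # the split below yields
+    #   ["<|user|>", "<|image_1|>", "Describe it.<|end|><|assistant|>"];
+    # the "<|image_1|>" chunk is replaced by its precomputed placeholder
+    # token ids and the remaining text chunks are tokenized normally.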
+    if prompt_str:
+        pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)"
+        prompt_chunk_strings = re.split(pattern, prompt_str)
+        prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""]
+
+        # Create the new input_ids with the placeholder image and audio
+        # tokens inserted
+        tokenizer = cached_tokenizer_from_config(ctx.model_config)
+        input_ids = []
+        has_image, has_audio, has_user_text_input = False, False, False
+        for prompt_chunk_string in prompt_chunk_strings:
+            if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string):
+                input_ids.extend(image_id_to_input_ids[prompt_chunk_string])
+                has_image = True
+            elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string):
+                input_ids.extend(audio_id_to_input_ids[prompt_chunk_string])
+                has_audio = True
+            else:
+                curr_token_ids = tokenizer(prompt_chunk_string).input_ids
+                if not has_user_text_input:
+                    for token_id in curr_token_ids:
+                        if token_id not in NON_USER_INPUT_TOKENS:
+                            has_user_text_input = True
+                            break
+                input_ids.extend(curr_token_ids)
+        if has_audio and has_image and has_user_text_input:
+            raise ValueError(
+                "Phi4MMForCausalLM does not support text + audio + image" +
+                " inputs in the same prompt")
+    # Handle the case where the prompt is already tokenized
+    else:
+        assert prompt_token_ids is not None, \
+            "If string prompt isn't provided, prompt_token_ids must be provided"
+
+        i = 0
+        input_ids = prompt_token_ids
+        # only needed for later assertion
+        img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0
+        image_token_count_iter = iter(image_token_counts)
+        audio_embed_size_iter = iter(audio_embed_sizes)
+        while i < len(input_ids):
+            token_id = input_ids[i]
+            if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID:
+                token_count = next(audio_embed_size_iter)
+                audio_cnt += 1
+            elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID:
+                token_count = next(image_token_count_iter)
+                img_cnt += 1
+            else:
+                user_text_input_cnt += 1 if token_id not in \
+                    NON_USER_INPUT_TOKENS else 0
+                i += 1
+                continue
+            tokens = [token_id] * token_count
+            input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
+            i += token_count
+
+        if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0:
+            raise ValueError(
+                "Phi4MMForCausalLM does not support text + audio + image" +
+                " inputs in the same prompt")
+        # If the below assertion fails, it might be that input pure-text
+        # messages contain image/audio special tokens literally
+        # (<|endoftext10|>, <|endoftext11|>).
+        assert (img_cnt == len(image_token_counts)), (
+            f"Number of image tokens in prompt_token_ids ({img_cnt}) "
+            f"does not match number of images ({len(image_token_counts)})")
+        assert (audio_cnt == len(audio_embed_sizes)), (
+            f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
+            f"does not match number of audios ({len(audio_embed_sizes)})")
+
+    # NOTE: Create a defensive copy of the original inputs
+    return token_inputs(
+        prompt_token_ids=input_ids,
+        prompt=prompt_str,
+        multi_modal_data=multi_modal_data,
+    )
+
+
+def _compute_audio_embed_size(hf_config, audio_frames):
+    """
+    Compute the audio embedding size based on the audio frames and
+    compression rate.
+ """ + compression_rate = hf_config.embd_layer['audio_embd_layer'][ + 'compression_rate'] + # NOTE: this is a hard-coded value but might be configurable in the future + qformer_compression_rate = 1 + integer = audio_frames // compression_rate + remainder = audio_frames % compression_rate + + result = integer if remainder == 0 else integer + 1 + + integer = result // qformer_compression_rate + remainder = result % qformer_compression_rate + result = integer if remainder == 0 else integer + 1 # qformer compression + + return result + + +def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: + return 10000 + + +def dummy_audio_for_phi4mm(audio_count: int) -> dict: + """ + Create dummy audio data for the Phi4MM model, which is used for profiling. + + Args: + audio_count (int): Number of audio samples. + + Returns: + dict: Dummy audio data. + """ + dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) + return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count + + +def dummy_image_for_phi4mm(width: int, height: int): + image = Image.new('RGB', (width, height), color='black') + return image + + +def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]) -> DummyData: + """ + Create dummy sequence (input_ids) and audio data for the Phi4MM model, + which is used for profiling. + + In this case, the sequence data is a bunch of 0s with a number of audio + tokens that correspond to the audio embed size of the + _AUDIO_MAX_SOUNDFILE_SIZE. + + Args: + ctx (InputContext): Input context. + seq_len (int): Length of the sequence. + mm_counts (Mapping[str, int]): Multi-modal counts. + + Returns: + Tuple: Dummy sequence data and dummy audio data. + """ + audio_count = mm_counts["audio"] + audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, + DUMMY_SAMPLING_FREQUENCY) + audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), + audio_frames) + + image_count = mm_counts["image"] + dummy_image = get_max_dummy_image(ctx) + max_image_tokens = get_max_phi4mm_image_tokens(ctx) + total_image_tokens = image_count * max_image_tokens + + if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: + raise RuntimeError( + f"Phi4MM cannot process {audio_count} audios and {image_count}" + f"images in a prompt, please increase max_model_len to be at" + f" larger than " + f"{audio_feature_size * audio_count + total_image_tokens}" + " or reduce audio/image limit by --limit-mm-per-prompt.") + + if audio_feature_size * audio_count > total_image_tokens: + seq_data = SequenceData.from_prompt_token_counts( + (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), + (0, seq_len - audio_feature_size * audio_count), + ) + mm_data = { + "audio": dummy_audio_for_phi4mm(audio_count), + } + else: + seq_data = SequenceData.from_prompt_token_counts( + (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), + (0, seq_len - total_image_tokens), + ) + mm_data = { + "image": [dummy_image] * image_count, + } + return DummyData(seq_data, mm_data) + + +def input_mapper_for_phi4mm_audio(ctx: InputContext, + data: object) -> MultiModalInputs: + """ + This function is used to create the MultiModalInputs for the Phi4MM + (audio) model. + Specifically, for audio, we extract the audio features from the sound + file and create pairs of audio features and audio embed lengths (the + latter of which is used to repeat the audio placeholder token in the + input prompt IDs). 
+ These pairs are used, downstream, in `_audio_features_to_embeddings` + (via `_process_audio_input`). + + Note that the incoming audio data (each entry in `data`) is a tuple of + the audio data and the sampling frequency (e.g. from soundfile.read). + + Args: + ctx (InputContext): Input context. + data (object): Audio data. + + Returns: + MultiModalInputs: Multi-modal inputs. + """ + if not isinstance(data, list): + data = [data] + + if len(data) == 0: + return MultiModalInputs() + + audio_features = [] + for audio_input in data: + if not isinstance(audio_input, tuple): + raise NotImplementedError( + f"Unsupported data type: {type(audio_input)}") + + audio, sf = audio_input + feature_extractor = audio_feature_extractor() + single_audio_features = feature_extractor.extract_features(audio, sf) + feat_stride = (1 if not hasattr(feature_extractor, "stride") else + feature_extractor.stride) + audio_frames = len(single_audio_features) * feat_stride + single_audio_embed_size = _compute_audio_embed_size( + ctx.get_hf_config(), audio_frames) + single_audio_feature_audio_len_pair = ( + single_audio_features, + [single_audio_embed_size], + ) + audio_features.append(single_audio_feature_audio_len_pair) + return MultiModalInputs({"audio_features": audio_features}) + + +def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): + if not isinstance(data, list): + data = [data] + # data: list of PIL images + if len(data) == 0: + return MultiModalInputs() + hf_config = ctx.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + dynamic_hd_size = prepro_config['dynamic_hd'] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + + image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, + vit_patch_size) + return MultiModalInputs({ + "pixel_values": + image_input_dict["pixel_values"], + "image_sizes": + image_input_dict["image_sizes"], + "image_attention_mask": + image_input_dict["image_attention_mask"], + "num_img_tokens": + image_input_dict["num_img_tokens"], + }) + + +def cat_with_pad(tensors, dim, padding_value=0): + """ + cat along dim, while pad to max for all other dims + """ + ndim = tensors[0].dim() + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" + + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) + + index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) + + output[slices] = t + index += t.shape[dim] + + return output + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", + input_mapper_for_phi4mm_audio) +@MULTIMODAL_REGISTRY.register_input_mapper("image", + input_mapper_for_phi4mm_image) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_max_phi4mm_audio_tokens) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "image", get_max_phi4mm_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) +@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) +class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): + """ + Implements the 
Phi-4-multimodal-instruct model in VLLM. + """ + # LoRA specific attributes + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj" + ] + # Phi4MMForCausalLM does not apply LoRA to the embedding layer. + embedding_modules = {} + embedding_padding_modules = [] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + assert multimodal_config, "multimodal_config is required" + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.multimodal_config = multimodal_config + self.quant_config = quant_config + self.lora_config = lora_config + + # Tensor/Pipeline parallel not supported for now. + assert get_tensor_model_parallel_world_size( + ) == 1, "tensor parallel is not supported" + assert get_pp_group( + ).world_size == 1, "pipeline parallel is not supported" + + self.vision_encoder = Phi4MMImageEncoder( + config, + quant_config, + prefix="model.vision_embed_tokens", + model_dir=config._name_or_path) + + if isinstance(config.embd_layer["audio_embd_layer"], dict): + embedding_config = { + "embedding_cls": + config.embd_layer["audio_embd_layer"]["embedding_cls"], + **config.embd_layer["audio_embd_layer"], + } + else: + embedding_config = { + "embedding_cls": self.config.embd_layer["embedding_cls"] + } + + self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) + self.model = LlamaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + def _audio_features_to_embeddings( + self, + input_ids: torch.Tensor, + input_features: List[torch.Tensor], + audio_input_sizes: torch.Tensor, + audio_projection_mode: str, + ) -> torch.Tensor: + """ + Convert audio features to embeddings, which are used as input to the + model (via `inputs_embeds`). + + Args: + input_ids (torch.Tensor): Input IDs (the prompt in this case). + input_features (list[torch.Tensor]): Input features (the audio + embeddings). + audio_input_sizes (list[torch.Tensor]): Audio input sizes (the + audio embed lengths to use for padding the audio placeholder token + in the input prompt IDs). 
+ """ + # The audio projection can either be a single linear or Sequential, + # so handle both cases + if isinstance(self.embed_tokens_extend.audio_projection, + nn.Sequential): + target_dtype = self.embed_tokens_extend.audio_projection[ + 0].bias.dtype + else: + target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype + + audio_input = [ + input.unsqueeze(0).to(target_dtype) for input in input_features + ] + kwargs = { + "wte": self.model.embed_tokens, + 'audio_projection_mode': audio_projection_mode + } + audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, + audio_input_sizes, + **kwargs) + audio_embeddings = audio_embeddings.to(target_dtype) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + """ + Parse and validate the audio input to the model. This handles both + audio features and audio embeddings, but only the former is used for + now. + + Args: + kwargs (object): Keyword arguments. + + Returns: + Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. + """ + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(audio_features)}") + + return Phi4MMAudioFeatureInputs(type="audio_features", + data=audio_features) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input(self, input_ids: torch.Tensor, + audio_input: Phi4MMAudioInputs, + audio_projection_mode: str) -> NestedTensors: + """ + Create the audio embeddings from the audio input, where the audio input + is pairs of audio features and audio embed lengths. The audio input is + created by `input_mapper_for_phi4mm_audio`. + + Args: + input_ids (torch.Tensor): Input IDs (the prompt in this case, + before the audio token replication). + audio_input (Phi4MMAudioInputs): Audio input. + + Returns: + NestedTensors: Audio embeddings + """ + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + # (e.g. multiple examples) and the second dim is the multi-audio dim + # (e.g. 
multiple audios in the same example) + audio_feature = [i[0] for j in audio_features for i in j] + audio_feature_len = [i[1].item() for j in audio_features for i in j] + # Add the batch dim via `squeeze` + + return self._audio_features_to_embeddings( + input_ids.unsqueeze(0), + audio_feature, + audio_feature_len, + audio_projection_mode, + ).squeeze(0) + + def _parse_and_validate_image_input(self, + **kwargs: object) -> Optional[Dict]: + pixel_values: Optional[Dict] = kwargs.get("pixel_values") + if pixel_values is None: + return None + + image_sizes = kwargs.get("image_sizes") + image_attention_mask = kwargs.get("image_attention_mask") + num_img_tokens = kwargs.get("num_img_tokens") + assert image_sizes is not None and image_attention_mask is not None\ + and num_img_tokens is not None, "Missing image inputs" + + if isinstance(pixel_values, list): + assert pixel_values[0].dim() == 5, "Incorrect image inputs" + # list len is batch_size. + # each tensor has dimension: num_img_per_example, num_hd_patches, + # channels, height, width. + # need to pad along num_hd_patches. + # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. + pixel_values = cat_with_pad(pixel_values, dim=0) + elif isinstance(pixel_values, torch.Tensor): + # dimension: batch_size, num_img_per_example, num_hd_patches, + # channels, height, width. + # we flatten first 2 dims to make it a single large batch for + # SigLIP Encoder. + assert pixel_values.dim() == 6, "Incorrect image inputs" + pixel_values = pixel_values.flatten(0, 1) + else: + raise ValueError("Incorrect pixel_values inputs") + + if isinstance(image_attention_mask, list): + image_attention_mask = cat_with_pad(image_attention_mask, dim=0) + elif isinstance(image_attention_mask, torch.Tensor): + image_attention_mask = image_attention_mask.flatten(0, 1) + else: + raise ValueError("Incorrect image_attention_mask inputs") + + if isinstance(image_sizes, list): + image_sizes = torch.cat(image_sizes, dim=0) + elif isinstance(image_sizes, torch.Tensor): + image_sizes = image_sizes.flatten(0, 1) + else: + raise ValueError("Incorrect image_attention_mask inputs") + + if isinstance(num_img_tokens, list): + num_img_tokens = [ + n for num_tensor in num_img_tokens + for n in num_tensor.tolist() + ] + elif isinstance(num_img_tokens, torch.Tensor): + num_img_tokens = num_img_tokens.flatten(0, 1).tolist() + else: + raise ValueError("Incorrect image_attention_mask inputs") + + return { + 'pixel_values': pixel_values, + 'image_sizes': image_sizes, + 'image_attention_mask': image_attention_mask, + 'num_img_tokens': num_img_tokens, + } + + def merge_image_features_to_inputs_embeds( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_set_tensors: List[torch.Tensor], + ): + position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( + as_tuple=True) + + assert all([t.shape[0] == 1 for t in image_set_tensors + ]), 'img_set_tensor should have shape (1, N_tokens, C)' + # Shape: (merged_N_tokens, C) + image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) + image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( + inputs_embeds.device) + merged_embeds = inputs_embeds.index_put( + indices=position_tuple, + values=image_set_tensor, + accumulate=False, + ) + return merged_embeds + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> None: + weights = {name: weight for name, weight in weights} + adjusted_weights = {} + + for name, weight in weights.items(): + # NOTE vision-speech tasks use a separate projection 
layer + audio_proj_4v = \ + "model.embed_tokens_extend.audio_embed.audio_projection.vision" + if name.startswith(audio_proj_4v): + name = name.replace( + audio_proj_4v, + "embed_tokens_extend.audio_projection_for_vision") + + name = (name.replace( + "model.embed_tokens_extend.audio_embed."\ + "audio_projection.speech.", + "embed_tokens_extend.audio_projection.", + ).replace( + "model.embed_tokens_extend.audio_embed.", + "embed_tokens_extend.", + ).replace("model.embed_tokens_extend.image_embed.", + "vision_encoder.")) + # NOTE: this is deal with LoRA injection, where `base_layer` + # remains as the original layer in the model + if name.endswith(".base_layer.weight"): + name = name.replace(".base_layer.weight", ".weight") + adjusted_weights[name] = weight + + missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, + strict=False) + logger.debug("*** missing keys:") + for key in missing_keys: + logger.debug(key) + logger.debug("**** unexpected keys:") + for key in unexpected_keys: + logger.debug(key) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + # Each entry in this is a pair of audio_features and audio_embed + # lengths + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_inputs = self._parse_and_validate_image_input(**kwargs) + + has_audio = audio_input is not None + has_image = image_inputs is not None + + if has_audio: + audio_projection_mode = 'vision' if has_image else 'speech' + inputs_embeds = self._process_audio_input( + input_ids, audio_input, audio_projection_mode) + + if has_image: + dtype = self.vision_encoder.img_processor.embeddings.\ + patch_embedding.weight.dtype + pixel_values = image_inputs['pixel_values'].to(dtype) + image_sizes = image_inputs['image_sizes'] + image_attention_mask = image_inputs['image_attention_mask'] + image_set_tensors = self.vision_encoder( + pixel_values, image_sizes, image_attention_mask) + if not has_audio: + inputs_embeds = self.model.embed_tokens(input_ids) + + inputs_embeds = self.merge_image_features_to_inputs_embeds( + input_ids, inputs_embeds, image_set_tensors) + + if has_image or has_audio: + # multi-modal input, we have set inputs_embeds properly in + # previous steps + input_ids = None + else: + # text-only, we keep using original input_ids + inputs_embeds = None + + hidden_states = self.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py new file mode 100644 index 0000000000000..f9d4881c55e29 --- /dev/null +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -0,0 +1,1403 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) +# but implemented by the Phi-Speech team +#!/usr/bin/env python3 +import abc +import math +from functools import partial +from typing import Callable, Dict, List, Literal, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, CheckpointWrapper, checkpoint_wrapper, offload_wrapper) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel) +from torch.utils.checkpoint import checkpoint +from transformers import PretrainedConfig + +from vllm.model_executor.models.phi4mm_utils import ( + AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, + MultiHeadedAttention, NemoConvSubsampling, T5RelativeAttentionLogitBias, + adaptive_enc_mask, attn_checkpointing, embedding_checkpoint_wrapper, + get_offset, repeat, unfold_tensor, validate_checkpointing_config) + +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + + +def encoder_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], + layer_cls: type, + idx: int = 0, +) -> Callable: + """return encoder activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get( + "module", "transformer") + if target_layer_cls.lower() == "transformer": + target_layer_cls = ( + "EncoderLayer", + "ConformerEncoderLayer", + ) + elif target_layer_cls.lower() == "attention": + target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") + checkpointing_interval = activation_checkpointing.get("interval", 1) + offloading = activation_checkpointing.get("offload", False) + impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get( + "reentrant", True) else CheckpointImpl.NO_REENTRANT) + + if (idx % checkpointing_interval == 0 + and layer_cls.__name__ in target_layer_cls): + if offloading: + return offload_wrapper + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + + raise ValueError("Invalid activation_checkpointing config") + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a + channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. 
+ causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_innner_dim: int, optional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_innner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention + implementation in training. 
+ attn_group_sizes: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_innner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = encoder_checkpoint_wrapper( + activation_checkpointing, + MultiHeadedAttention, + )(MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention= + use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + )) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. 
When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. + Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", + True] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling(**default_nemo_conv_settings, ) + else: + raise ValueError("unknown input_layer: " + 
input_layer) + + self.pos_emb = AbsolutePositionalEncoding(attention_dim, + positional_dropout_rate) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") + if relative_attention_bias_args else None) + if self.relative_attention_bias_type == "t5": + assert (self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get( + "t5_bias_max_distance", 1000), + symmetric=relative_attention_bias_args.get( + "t5_bias_symmetric", False), + ) + else: + raise NotImplementedError + + def post_init(self, init_model_config): + + pretrained_speech_encoder_path = init_model_config.get( + "pretrained_speech_encoder_path", None) + if pretrained_speech_encoder_path: + model_state = torch.load(pretrained_speech_encoder_path, + map_location="cpu") + encoder_state_dict = {} + for k, v in model_state.items(): + if "encoder." in k: + tmp_k = k.replace("encoder.", "") + encoder_state_dict[tmp_k] = v + + if hasattr(self, "encoder_embedding"): + del self.encoder_embedding + self.load_state_dict(encoder_state_dict) + + if not hasattr(self, "encoder_embedding"): + self.encoder_embedding = MeanVarianceNormLayer( + self.encoder_embedding_config["input_size"]) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. + If you really need this to be faster, create nn.Module()-s for all + the cases and return one of them. Torchscript does support that. + """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get( + "subsampling", "dw_striding") in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = (torch.ceil(feature_lens / + self.time_reduction).long() + if isinstance(feature_lens, Tensor) else + math.ceil(feature_lens / self.time_reduction)) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = (math.ceil + if isinstance(feature_lens, int) else torch.ceil) + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int( + torch.randint(low=0, high=len(chunk_size), size=(1, ))) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError( + "Since chunk_size is a list, left_chunk must be a list") + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of "\ + "chunk_size." 
+ ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb( + input_tensor) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = \ + self._chunk_size_selection(chunk_size, left_chunk) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = (adaptive_enc_mask( + seq_len, chunk_start_idx, + left_window=left_chunk_train_eff).unsqueeze(0).expand( + [batch_size, -1, -1])) + return enc_streaming_mask + + def forward_embeddings(self, + xs_pad, + masks, + chunk_size_nc=None, + left_chunk_nc=None): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for + non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for + non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The sequence length after time reduction is invalid: + {seq_len}. Your input feature is too short. 
Consider + filtering out the very short sentence from data + loader""", ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, + self.chunk_size, + self.left_chunk) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core( + input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. + num_lang: int + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. + default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_block: + number of Transformer layer. default 6 + dropout_rate: float, optional + dropout rate. default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. 
+ cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channel for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only work for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channel for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation used use glu in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_glu_type: str + only work for glu_in_attention !=0 + default "swish". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled + dot product attention in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming + decoding, use "replication" padding for the cache at start of + utterance. 
+ Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: List[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], # noqa + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", + True] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.embed = embedding_checkpoint_wrapper(activation_checkpointing)( + self.embed) + self.replication_pad_for_subsample_embedding: bool = ( + replication_pad_for_subsample_embedding) + assert (self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = repeat( + num_blocks, + lambda i: encoder_checkpoint_wrapper(activation_checkpointing, + ConformerEncoderLayer, i) + (ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel= + depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=attn_checkpointing( + activation_checkpointing, i), + export=export, + use_pt_scaled_dot_product_attention= + use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + )), + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + 
self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, + self.chunk_size, + self.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = (torch.arange(0, max_audio_length, + device=device).expand(padding_length.size(0), + -1) + < padding_length.unsqueeze(1)) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( + xs_pad, masks) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 #maximum position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks + # of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple + # of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad(input_tensor, + (0, 0, 0, chunk_pad_size), "constant", + 0) + input_tensor = input_tensor_pad.to(input_tensor.device) + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask + # did not consider extra pad + subsampled_pad_mask = masks.squeeze( + 1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", + False) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = \ + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze( + -1).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask( + input_tensor, input_tensor.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + # layer_emb = None + + relative_attention_bias = self.init_relative_attention_bias( + input_tensor) + + _simplified_path = (self.extra_layer_output_idx == -1 + and relative_attention_bias is None) + + if _simplified_path: + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, + hs_mask) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + # if i == self.extra_layer_output_idx: + # layer_emb = input_tensor + + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + + return 
input_tensor, masks # , layer_emb + + def gradient_checkpointing_enable(self): + pass + + +class WindowQformer(nn.Module): + """Window-level Qformer""" + + def __init__( + self, + window_size: int = 8, + num_queries: int = 1, + num_blocks: int = 2, + attention_dim: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + + self.decoders = nn.ModuleList([ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) for _ in range(num_blocks) + ]) + + self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) + self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12) + if normalize_before else None) + self.window_size = window_size + self.gradient_checkpointing_enable = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing_enable = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing_enable = False + + def forward(self, audio_embed, mask, embed_len=None): + """forward decoder""" + # audio_embed: N x T x D => N x D x T + + audio_embed = audio_embed.transpose(1, 2) + # audio_embed: N x D x 1 x T => N x DK x T' + padding = audio_embed.shape[-1] % self.window_size + if padding > 0: + audio_embed = F.pad(audio_embed, (0, self.window_size - padding), + "constant", 0) + + embed_chunk = F.unfold( + audio_embed[..., None, :], + kernel_size=(1, self.window_size), + stride=(1, self.window_size), + ) + bsz, _, slen = embed_chunk.shape + # N x D x K x T' + embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) + # N x T' x K x D + embed_chunk = embed_chunk.transpose(1, 3).contiguous() + # NT' x K x D + embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) + # NT' x 1 x D + q = self.queries.expand(bsz * slen, -1, -1) + for layer in self.decoders: + if self.gradient_checkpointing_enable and self.training: + q = checkpoint( + layer.__call__, + q, + embed_chunk, + None, + mask, + use_reentrant=True, + ) + else: + q = layer(tgt=q, + memory=embed_chunk, + tgt_mask=None, + memory_mask=mask) + + if self.after_norm is not None: + q = self.after_norm(q) + + if embed_len is not None: + embed_len = embed_len // self.window_size + # N x T' x D + out = q.view(bsz, slen, -1) + + return out, embed_len + + +class AudioEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = (config.n_embd + if hasattr(config, "n_embd") else config.hidden_size) + + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = (config.embd_pdrop if hasattr(config, "embd_pdrop") + else config.embed_pdrop) + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + # self.wte = nn.Embedding(config.vocab_size, hidden_size) + + audio_dim_out = ( + None # Set this variable according to the actual audio processor + ) + self.layer_idx = -2 + + if (isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades"): + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + # fake initialization, create encoder_embedding layer only so that + # in decoding, all parameters can be loaded 
in + # from_pretrained_function in training, we do post init after + # from_pretrained function to make sure the correct initialization + self.encoder.post_init({}) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError("") + + assert (audio_dim_out + is not None), "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", + False) + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + if kwargs.get("use_qformer", False): + qformer_config = kwargs.get("qformer_config", {}) + qformer_config["attention_dim"] = audio_dim_out + self.qformer = WindowQformer(**qformer_config) + else: + self.qformer = None + + if kwargs.get("use_conv_downsample", False): + assert (self.qformer is None + ), "don't support use qformer and conv downsample together" + nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.downsample_rate, + "feat_in": audio_dim_out, + "feat_out": audio_dim_out, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, ) + else: + self.conv_ds = None + + enable_gradient_checkpointing = kwargs.get( + "enable_gradient_checkpointing", False) + if enable_gradient_checkpointing: + self.encoder.gradient_checkpointing_enable() + + if self.qformer: + self.qformer.enable_gradient_checkpointing() + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds) + else self.downsample_rate) + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, + dim_projection) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.audio_projection = nn.Sequential(*layers) + # NOTE vision-speech tasks use a separate projection layer + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, + dim_projection) + ] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.audio_projection_for_vision = nn.Sequential(*layers) + else: + raise NotImplementedError( + f"projection_cls = {projection_cls}, not implemented") + + # TODO: audio sequence compression - Qformer + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes(self, + audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, + input_embeds: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: 
str = "speech", + ): + + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder(input_embeds, + audio_attention_mask) + else: + audio_features, masks = self.encoder(input_embeds, + audio_attention_mask) + + if self.qformer is not None: + audio_features, _ = self.qformer(audio_features, mask=None) + + if self.conv_ds is not None: + if masks is not None: + masks = masks.squeeze(1) + + audio_features, masks = self.conv_ds(audio_features, mask=masks) + + if self.linear_downsample_rate != 1: + bs, seq_len, feat_dim = audio_features.size() + padding = seq_len % self.linear_downsample_rate + if padding > 0: + audio_features = F.pad( + audio_features, + (0, 0, 0, self.linear_downsample_rate - padding), + "constant", + 0, + ) + + seq_len = audio_features.size(1) + audio_features = audio_features.view( + bs, + seq_len // self.linear_downsample_rate, + feat_dim * self.linear_downsample_rate, + ) + + if audio_projection_mode == 'speech': + audio_set_tensor = self.audio_projection(audio_features) + elif audio_projection_mode == 'vision': + audio_set_tensor = self.audio_projection_for_vision(audio_features) + else: + raise ValueError( + f"audio_projection_mode = {audio_projection_mode} not "\ + "implemented" + ) + + return audio_set_tensor + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds: torch.FloatTensor, + audio_embed_sizes, + **kwargs, + ) -> torch.FloatTensor: + """ + arguments: + input_ids: input text ids (B, U) + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + assert input_embeds is not None and len(input_embeds) == len( + audio_embed_sizes) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + with torch.no_grad(): + positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero( + as_tuple=False) + + if not isinstance(input_embeds, list): + input_embeds = [input_embeds] + + audio_projection_mode = kwargs.get("audio_projection_mode", "speech") + audio_set_tensor = [ + self.get_audio_features( + input_embed, audio_projection_mode=audio_projection_mode) + for input_embed in input_embeds + ] + + with torch.no_grad(): + input_ids.clamp_min_(0).clamp_max_(self.vocab_size) + + if "wte" in kwargs: + # we use the token embedding layer from the huggingface model, this + # is REQUIRED to make sure we are using the loaded weights. + hidden_states = kwargs["wte"](input_ids) + else: + # otherwise, we use token embedding in pretrained mixformer from + # phi team + hidden_states = self.wte(input_ids) + + if len(positions.tolist()) > 0: + assert sum(audio_embed_sizes) == len( + positions + ), "please ensure the encoder outputs have the same length as"\ + " defined in input_ids!" 
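+            # `positions` holds the (batch, seq) index of every audio
+            # placeholder token, in prompt order, and audio_embed_sizes[i]
+            # gives how many placeholder slots belong to audio i. The loop
+            # below overwrites each contiguous run of placeholder embeddings
+            # with the first `cnt` projected frames of the matching audio,
+            # advancing `idx` by `cnt` so it always points at the first
+            # placeholder of the next audio.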
+ idx = 0 + for i in range(len(audio_embed_sizes)): + cnt = audio_embed_sizes[i] + assert audio_set_tensor[i].shape[0] == 1 + hidden_states[ + positions[idx, 0], + positions[idx, 1]:positions[idx, 1] + cnt, + ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to( + hidden_states.dtype).to(hidden_states.device)) + idx += cnt + + else: + if self.training: + # hidden_states[:, 0:img_set_tensor.shape[0]] = + # hidden_states[:, 0:img_set_tensor.shape[0]] + + # 0 * img_set_tensor.to(hidden_states.dtype) + # .to(hidden_states.device) + hidden_states[:, 0:1] = hidden_states[:, 0:1] + \ + 0 * audio_set_tensor[:, 0:1].to(hidden_states.dtype)\ + .to(hidden_states.device) + + if self.drop is not None: + hidden_states = self.drop(hidden_states) + return hidden_states diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py new file mode 100644 index 0000000000000..16b62c60836e7 --- /dev/null +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -0,0 +1,1969 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com) +# but implemented by the Phi-Speech team +#!/usr/bin/env python3 +import math +from functools import partial +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, checkpoint_wrapper, offload_wrapper) + + +class Block(nn.Module): + """Block abstract module""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for + chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor(chunk_start_idx).long( + ) # first idx of each chunk, such as [0,18,36,48]. 
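+    # Worked example of the index arithmetic below: with x_len=4,
+    # chunk_start_idx=[0, 2], left_window=0, right_window=0 we get
+    # start_pad=[0, 0, 2] and end_pad=[0, 2, 4], so frames 0-1 may only
+    # attend within chunk [0, 2) and frames 2-3 within chunk [2, 4).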
+ start_pad = torch.nn.functional.pad( + chunk_start_idx, + (1, 0)) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, + x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & + (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + # boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Swish(nn.Module): + """Implement Swish activation module. + From https://arxiv.org/pdf/2005.03191.pdf + + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function + + Args: + x: torch.Tensor + Input. + """ + return x * self.act_fn(x) + + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. + kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. 
+ + """ + + def __init__( + self, + input_dim, + output_dim, + kernel_size, + glu_type="sigmoid", + bias_in_glu=True, + causal=False, + ): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1), + ) + else: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the + # #channel (#dim) in the last dimension of the tensor, so need to + # switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0:self.output_dim, :] + self.b1) * ( + x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + else: + x = (x[:, 0:self.output_dim, :]) * ( + x[:, self.output_dim:self.output_dim * 2, :]) + else: + if self.bias_in_glu: + x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + else: + x = (x[:, 0:self.output_dim, :]) * self.glu_act( + x[:, self.output_dim:self.output_dim * 2, :]) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a channel_out + of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. + + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d( + input_dim * depthwise_multiplier, + depthwise_seperable_out_channel, + 1, + 1, + 0, + ) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. 
+ ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation. + default False + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + chunk_size: int, optional + chunk size for cnn. default 18 + activation: str, optional + activation function used in ConvModule, + default: "relu". + glu_type: str, optional + activation function used for the glu, + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + export: bool, optional, + if set to True, padding is equal to 0. This is for inference, + or onnx export. Typically this is set by the export program or + the decoder program, and it isn't present in your config file. + default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + padding = 0 if export else kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, + input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, + input_dim) + + def _add_ext_pw_layer(self): + """ + This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer. 
+ """ + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( + nn.Identity()) # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. + + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear( + self.input_dim, + self.ext_pw_out_channel, + self.glu_type, + self.bias_in_glu, + ) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. + """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, :-(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, :-(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, :-(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + + +class GLULinear(nn.Module): + """Linear + GLU module + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. + """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. + d_inner: int + output size. + dropout_rate: float, + dropout rate. 
+ activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + + +#### positional encoding starts here +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward + compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section + 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + https://github.com/huggingface/transformers/blob/v4.30.0/src/ + transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based + on relative position of the query and key. It is HxNxN, where H is the + number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is + supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), + and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to + yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the + size of the learnable bias parameter. Bucketing is not yet + supported, so this defaults to -1 which means no bucketing is + used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all + positions are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses + 2x number of bias params to distinguish L->R from R->L. This was + found to be better for the encoder. 
+ """ + + def __init__(self, + num_heads, + num_buckets=-1, + max_distance=1000, + symmetric=False): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError( + "T5 attention bias with bucketed positions is not yet tested") + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange(maxpos, + device=x.device, + dtype=torch.long)[:, None] + memory_position = torch.arange(maxpos, + device=x.device, + dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX + # export + relative_position = relative_position.masked_fill( + relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze( + 0) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace + # implem as a reference this also needs to be extended to support + # asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + self.num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long) * self.num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = self.num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / + math.log(self.max_distance / max_exact) * + (self.num_buckets - max_exact)).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, self.num_buckets - 1), + ) + + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) + return relative_buckets + + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. 
+ dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. + + Args: + x: torch.Tensor + """ + if self.pe is not None and self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, :x.size(1)] + return self.dropout(x) + + +#### forward embedding layers starts here +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will subtract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.register_buffer("global_mean", torch.zeros(input_size)) + self.register_buffer("global_invstd", torch.ones(input_size)) + self.global_mean: Optional[Tensor] + self.global_invstd: Optional[Tensor] + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would have limited access to + locations on its right or left + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set automatically to make it a + causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as + left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible + on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should + be equal to (kernel_size - 1). 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError( + "No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif (isinstance(padding, list) and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, :-self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1):] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would + have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be + set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError( + "Argument padding should be set to None for CausalConv2D.") + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + + padding = 0 + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + def forward( + self, + x, + ): + if self.training: + x = F.pad( + x, + pad=( + self._left_padding, + self._right_padding, + self._left_padding, + self._right_padding, + ), + ) + else: + x = F.pad( + x, + pad=(self._left_padding, self._right_padding, 0, 0), + ) + x = super().forward(x) + return x + + +class NemoConvSubsampling(torch.nn.Module): + """Convlutional subsampling module, taken from NeMo ASR + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a + 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + et al. (https://ieeexplore.ieee.org/document/8462506) + + + Compared with the EncoderConv2D (`input_layer: custom`), this is a + much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. 
+ Moreover, depthwise convolutions are used to reduce FLOPs, but the first + layer is kept as a regular convolution so as not to degrade accuracy. + + `Striding` and `dw_striding` are the same except that the latter uses + depthwise convolutions after the first layer, whereas the former does not. + + Args: + subsampling_factor (int): Time reduction factor + feat_in (int): size of the input features + feat_out (int): size of the output features + subsampling (str): The subsampling technique, choose from + {"striding", "dw-striding", "striding_conv1d", + "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, + default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which + can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will + have limited access to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), # noqa: B008 + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ( + "dw_striding", + "striding", + "striding_conv1d", + ) + + if (subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a "\ + "power of 2" + ) + self.subsampling_conv_chunking_factor = \ + subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + )) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + )) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + )) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + )) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + )) + 
layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + )) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + )) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == i + + 1 else conv_channels), + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + )) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == i + + 1 else conv_channels), + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + )) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend([ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == 1 else + conv_channels), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ]) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend([ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=(feat_out if self._sampling_num == i + + 2 else conv_channels), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ]) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) 
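+            # out_length is the size of the frequency axis after the stacked
+            # stride-2 convolutions; forward() flattens the (channel, freq)
+            # axes of each frame, so this linear layer projects
+            # conv_channels * out_length features down to feat_out.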
+ self.out = torch.nn.Linear(conv_channels * int(out_length), + feat_out) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """ + Forward method for NeMo subsampling. + + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // + time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // + time_reduction_factor) + """ + x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) + + # split inputs if chunking_factor is set + if (self.subsampling_conv_chunking_factor != -1 + and self.conv2d_subsampling): + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only + # if needed. + # avoiding a bug / feature limiting indexing of tensors + # to 2**31. + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = (2**31 / self._conv_channels * self._stride * + self._stride) + need_to_split = torch.numel(x) > x_ceil + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + if mask is None: + return x, None + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( + padding_length.size(0), -1) < padding_length.unsqueeze(1) + return x, pad_mask.unsqueeze(1) + + def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2)**-0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, + dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, + dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, + pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, + pw_max) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ + # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ + # src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / + self._sampling_num)**-0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, 
fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run conv and concat results""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return ( + torch.cat([ + self.conv(chunk) + for chunk in torch.split(x, new_batch_size, 0) + ]), + True, + ) + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat + results""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors + # to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, + x) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat( + [ + self.conv[i * 3 + 3](chunk) + for chunk in torch.split(x, new_t, 2) + ], + 2, + ) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind:ind + step, :, :, :], + bias=conv.bias[ind:ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind:ind + step, :, :, :], + bias=conv.bias[ind:ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor( + self, subsampling_conv_chunking_factor: int): + if (subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a "\ + "power of 2" + ) + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + +def calc_length(lengths, + all_paddings, + kernel_size, + stride, + ceil_mode, + repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or + max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + + one) + lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) + return lengths.to(dtype=torch.int) + + +#### multihead attention starts 
here +class AttModule(nn.Module): + """Attention abstraction module""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def set_export(self, mode=True): + """set the export mode""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward + + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. + """ + return x, memory, pos_emb, att_mask + + +class AttBlock(Block, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """memory dimensions""" + return (1, self.input_size) + + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding + and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention + implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for + value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. 
must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = ( + use_pt_scaled_dot_product_attention) + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. Note that these functions below are + # NOOPs and have very little impact on performance unless quantization + # is enabled. + self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. 
relative positions + (1, n_head, time1, time2) + """ + n_batch = query.size(0) + + q = self.linear_q(query).view(n_batch, -1, self.h, + self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, + self.d_k) # (b, t, d) + v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) + q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting() else q.transpose(1, 2) * + self.inv_sqrt_d_k) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if (self.use_pt_scaled_dot_product_attention + and not torch.jit.is_scripting()): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.backends.cuda.sdp_kernel(enable_flash=True, + enable_math=True, + enable_mem_efficient=True): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = (q.contiguous().view(n_batch * self.h, -1, + self.d_k).transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul(reshape_q, + pos_k.transpose(-2, + -1)) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), + pos_k.size(1)) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), + v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = (p_attn.contiguous().view( + n_batch * self.h, pos_v.size(0), + pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2) + + attn_v = (torch.matmul(reshape_attn, pos_v).transpose( + 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0), + self.d_k)) + x = x + attn_v + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h_k * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +def validate_checkpointing_config(activation_checkpointing): + """validate activation checkpointing configuration""" + if isinstance(activation_checkpointing, str): + assert activation_checkpointing in ( + "", + "checkpoint", + "offload", + ), "activation_checkpointing has to be a dict or a str in "\ + "('', 'checkpoint', 'offload')." + elif isinstance(activation_checkpointing, dict): + assert activation_checkpointing.get("module", "transformer") in ( + "transformer", + "attention", + ), "module in activation_checkpointing has to be in "\ + "('transformer', 'attention')." 
+ else: + raise ValueError("activation_checkpointing has to be a str"\ + " or dict.") + + +def embedding_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], ) -> Callable: + """return encoder embedding activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + enabled = activation_checkpointing.get("embed", False) + if enabled: + offloading = activation_checkpointing.get("offload", False) + if offloading: + return offload_wrapper + impl = (CheckpointImpl.REENTRANT if activation_checkpointing.get( + "reentrant", False) else CheckpointImpl.NO_REENTRANT) + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + raise ValueError("Invalid activation_checkpointing config") + + +def attn_checkpointing(activation_checkpointing: Union[str, Dict], + i) -> Union[str, Dict]: + """return activation checkpointing config for attention layer""" + if isinstance(activation_checkpointing, str): + return "" + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get( + "module", "transformer") + checkpointing_interval = activation_checkpointing.get("interval", 1) + if target_layer_cls == "attention" and i % checkpointing_interval == 0: + return activation_checkpointing + return "" + + raise ValueError("Invalid activation_checkpointing config") + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential""" + + @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + + +def repeat(repeat_num, module_gen_fn): + """repeat module N times + + :param int repeat_num: repeat time + :param function module_gen_fn: function to generate module + :return: repeated modules + :rtype: MultiSequential + """ + return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a + subsampled feature. + + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d", ) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + + +def unfold_tensor(xs_pad, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a + (NT', max_seq_len, D) where T' is T // max_seq_len. 
+ Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3a7fcdcf7b370..74160e2d9ee40 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -182,6 +182,7 @@ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), + "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/vision_siglip_navit.py b/vllm/model_executor/models/vision_siglip_navit.py new file mode 100644 index 0000000000000..3a9597a845ff9 --- /dev/null +++ b/vllm/model_executor/models/vision_siglip_navit.py @@ -0,0 +1,1966 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Siglip model configuration""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +from vllm.platforms import _Backend + +from .vision import get_vit_attn_backend + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": + "https://huggingface.co/google/siglip-base-patch16-224/"\ + "resolve/main/config.json", +} + + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`SiglipTextModel`]. 
It is used to instantiate a Siglip text encoder + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the text encoder of the Siglip [google/ + siglip-base-patch16-224](https://huggingface.co/google/siglip-base + -patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of + different tokens that can be represented by the `inputs_ids` + passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer + in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used + with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to + `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. 
+ Example: + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 + style configuration + >>> configuration = SiglipTextConfig() + >>> # Initializing a SiglipTextModel (with random weights) from the + google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773# + # issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", config_dict['model_type'], + cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the + model architecture. Instantiating a configuration with the defaults will + yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/ + siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer + in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to + `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and + `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 + style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the + google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + "You are using a model of type %s to " + "instantiate a model of type %s. This is not" + " supported for all configurations of models and can yield" + " errors.", config_dict['model_type'], cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a + [`SiglipModel`]. It is used to instantiate a Siglip model according to the + specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Siglip [google/siglip-base-patch16-224]( + https://huggingface.co/google/siglip-base-patch16-224) architecture. 
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize + [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize + [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + Example: + ```python + >>> from transformers import SiglipConfig, SiglipModel + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 + style configuration + >>> configuration = SiglipConfig() + >>> # Initializing a SiglipModel (with random weights) from the + google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig + and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + >>> config = SiglipConfig.from_text_vision_configs(config_text, + config_vision) + ```""" + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info( + "`text_config` is `None`. Initializing the `SiglipTextConfig`" + " with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the " + "`SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, + vision_config: SiglipVisionConfig, **kwargs): + r""" + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text + model configuration and siglip vision + model configuration. + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + **kwargs) + + +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
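+
+# The modeling code below is adapted from the HuggingFace Transformers
+# SigLIP implementation; helpers taken verbatim are marked with
+# "Copied from ..." comments.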
+""" PyTorch Siglip model.""" + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official + # releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/ + # truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa + u = norm_cdf((b - mean) / std) # noqa + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_(tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where + the bounds [a, b] are applied when sampling the normal distribution with + mean=0, std=1.0 and the result is subsequently scaled and shifted by the + mean and std args. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with +# CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings + of the pooling of the last hidden states. + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` + *optional* returned when model is initialized with + `with_projection=True`): + The image embeddings obtained by applying the projection layer to + the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with +# CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the + last hidden states. 
+ Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` + *optional* returned when model is initialized with + `with_projection=True`): + The text embeddings obtained by applying the projection layer to + model. + the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the + embeddings, if the model has an embedding layer, + one for the + output of each layer) of shape `(batch_size, sequence_length, + hidden_size)`. + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute + the weighted average in the self-attention heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with +# CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, + text_batch_size)`): + The scaled dot product scores between `image_embeds` and + `text_embeds`. This represents the image-text similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, + image_batch_size)`): + The scaled dot product scores between `text_embeds` and + `image_embeds`. This represents the text-image similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to + the pooled output of [`SiglipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to + the pooled output of [`SiglipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output" + ] else getattr(self, k).to_tuple() + for k in self.keys()) + + +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, \ + max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.linspace(0, 1 - 1 / nb_patches_h, + nb_patches_h) + fractional_coords_w = torch.linspace(0, 1 - 1 / nb_patches_w, + nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with +# CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, + embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and + # exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[ + -1] if input_ids is not None else inputs_embeds.shape[-2] + + if 
position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`:" + f" {self.embed_dim} and `num_heads`: {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, + k_v_seq_len): + raise ValueError( + f"Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError(f"Attention mask should be of size " + f"{(batch_size, 1, q_len, k_v_seq_len)}, " + f"but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, + self.head_dim): + raise ValueError( + f"`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as + the weights of the module stays untouched. 
The only required change would + be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any + of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) + + # TODO: These transpose are quite inefficient but Flash Attention + # requires the layout [batch_size, sequence_length, num_heads, + # head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently + # casted in float32. Hence, we need cast them back in the correct + # dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to + # not cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in " + "float32, this might be related to the fact you have upcasted " + "embedding or layer norm layers in float32. 
We will cast " + f"back the input in {target_dtype}.") + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, + self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + """ + Calls the forward method of Flash Attention - if the input hidden + states contain at least one padding token first unpad the input, + then computes the attention scores and pad the final attention + scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size + `(batch_size, seq_len)` where 0 stands for the position + of padding tokens and 1 for the position of non-padding + tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / + sqrt(head_dim) + """ + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa + + # TODO: Remove the `query_length != 1` check once Flash Attention for + # RoCm is bumped to 2.1. For details, please see the comment in + # LlamaFlashAttention2 __init__. 
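+        # Note: `self.is_causal` is hard-coded to False in __init__ above
+        # (the SigLIP encoder attends bidirectionally), so `causal` below is
+        # always False; the `query_length != 1` guard is inherited from the
+        # Llama flash-attention implementation referenced in the TODO above.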
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, \ + max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + from flash_attn.bert_padding import index_first_axis, unpad_input + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \ + unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with +# CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = (SiglipAttention(config) if + not getattr(config, "_flash_attn_2_enabled", False) + else SiglipFlashAttention2(config)) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where + padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under returned tensors for + more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface + for downloading and loading pretrained models. 
+ """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = (self.config.vision_config.hidden_size if isinstance( + self.config, SiglipConfig) else self.config.hidden_size) + nn.init.normal_(module.position_embedding.weight, + std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.tensor(0.0) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass + documentation for the generic methods the library implements for all + its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/ + stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation + for all matter related to general usage and behavior. + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the + parameters of the model. + Initializing with a config file does not load the weights + associated with the model, only the configuration. Check out + the [`~PreTrainedModel.from_pretrained`] method to load the + model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length) + `): + Indices of input sequence tokens in the vocabulary. Padding will + be ignored by default should you provide it. + Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, + num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you + provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] + for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, + sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding + will be ignored by default should you provide it. + Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` + , *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, + num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you + provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] + for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with +# CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`SiglipEncoderLayer`]. 
+ Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. + Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under returned tensors for + more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a + plain tuple. + """ + output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + +class SiglipTextTransformer(nn.Module): + + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + 
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states \ + is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the + # original CLIP model. + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> + # [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + >>> model = SiglipTextModel. + from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer. 
+ from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" + as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], + padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) + states + ```""" + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None\ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending + # to the whole sequence), avoiding passing the attention_mask, which + # is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = (_prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype) + if not self.config._flash_attn_2_enabled else + patch_attention_mask) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(query=probe, + key=hidden_state, + value=hidden_state, + key_padding_mask=~attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + >>> model = SiglipVisionModel.from_pretrained( + "google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + "google/siglip-base-patch16-224") + >>> url = + "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> 
last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise ValueError("config.text_config is expected to be of type " + f"SiglipTextConfig but is of type" + f" {type(config.text_config)}.") + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise ValueError("config.vision_config is expected to be of type " + "SiglipVisionConfig but is of type" + f" {type(config.vision_config)}.") + + text_config = config.text_config + vision_config = config.vision_config + + self.text_model = SiglipTextTransformer(text_config) + self.vision_model = SiglipVisionTransformer(vision_config) + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, + output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output + of [`SiglipTextModel`]. + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained( + "google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained( + "google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's + how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], + padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead + # of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None\ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, + output_dim`): The image embeddings obtained by applying the + projection layer to the pooled output of [`SiglipVisionModel`]. + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + "google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead + # of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, + config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + "google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was + trained with this + >>> inputs = processor(text=texts, images=image, + padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the + probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of + # those of vision & text components. 
+ output_attentions = output_attentions if output_attentions \ + is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else \ + self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm( + p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t( + )) * self.logit_scale.exp() + self.logit_bias + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + raise NotImplementedError("SigLIP loss to be implemented") + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, + image_embeds, text_outputs, vision_outputs) + return ((loss, ) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): + siglip_vision_config = { + "hidden_size": 1152, + "image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + # Detect attention implementation. + attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if attn_backend != _Backend.FLASH_ATTN: + _flash_attn_2_enabled = False + + model_config = SiglipVisionConfig( + **siglip_vision_config, + _flash_attn_2_enabled=_flash_attn_2_enabled, + **kwargs) + + vision_model = SiglipVisionModel(model_config).vision_model + + return vision_model
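Below is a minimal usage sketch (illustrative only, not part of the patch), assuming the file is importable as `vllm.model_executor.models.vision_siglip_navit`: it builds the vision tower defined above with randomly initialized weights and runs one dummy 448x448 image through it, forcing the eager attention path so that flash-attn is not required.

```python
import torch

from vllm.model_executor.models.vision_siglip_navit import (
    get_siglip_vision_model)

# Build the SigLIP-SO400M-patch14-448 tower on the eager attention path.
vision_model = get_siglip_vision_model(_flash_attn_2_enabled=False)
vision_model.eval()

pixel_values = torch.randn(1, 3, 448, 448)  # (batch, channels, height, width)
with torch.no_grad():
    out = vision_model(pixel_values)

# 448 / 14 = 32 patches per side -> 1024 patch tokens of width 1152.
print(out.last_hidden_state.shape)  # torch.Size([1, 1024, 1152])
print(out.pooler_output.shape)      # torch.Size([1, 1152])
```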