diff --git a/aphrodite/common/sequence.py b/aphrodite/common/sequence.py index 60e0116d86..a497b5f639 100644 --- a/aphrodite/common/sequence.py +++ b/aphrodite/common/sequence.py @@ -142,6 +142,7 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + is_encoder_decoder: bool, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id @@ -154,8 +155,20 @@ def __init__( self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] + initial_token_ids = prompt_token_ids + if is_encoder_decoder: + # We need to separate the prompt and generated tokens for + # encoder-decoder models. + num_prompt_blocks = (len(prompt_token_ids) + block_size - + 1) // block_size + padded_prompt_len = num_prompt_blocks * block_size + initial_token_ids = prompt_token_ids + [0] * ( + padded_prompt_len - len(prompt_token_ids)) + # Also need to append decoder_start_token_id + initial_token_ids.append(0) + # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(initial_token_ids) self.status = SequenceStatus.WAITING # Used for incremental detokenization diff --git a/aphrodite/endpoints/llm.py b/aphrodite/endpoints/llm.py index 18cdb3baa4..b48444666b 100644 --- a/aphrodite/endpoints/llm.py +++ b/aphrodite/endpoints/llm.py @@ -156,6 +156,10 @@ def generate( if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() + + if self.llm_engine.is_encoder_decoder: + assert (self.llm_engine.cache_config.context_shift is None + ), "Encoder-decoder models do not support context shift." # Add requests to the engine. num_requests = len(prompts) if prompts is not None else len( diff --git a/aphrodite/engine/aphrodite_engine.py b/aphrodite/engine/aphrodite_engine.py index b811bfeecc..fc3041dc09 100644 --- a/aphrodite/engine/aphrodite_engine.py +++ b/aphrodite/engine/aphrodite_engine.py @@ -125,6 +125,9 @@ def __init__( self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, labels=dict(model_name=model_config.model)) + + self.is_encoder_decoder = getattr(self.model_config.hf_config, + "is_encoder_decoder", False) def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) @@ -462,7 +465,7 @@ def add_request( block_size = self.cache_config.block_size seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - lora_request) + self.is_encoder_decoder, lora_request) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/aphrodite/modeling/layers/attention.py b/aphrodite/modeling/layers/attention.py index 89e981d892..d7d89b401e 100644 --- a/aphrodite/modeling/layers/attention.py +++ b/aphrodite/modeling/layers/attention.py @@ -226,17 +226,12 @@ def forward( ) else: - # Decoding run. - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_quant_param, - ) + # Decoding run + output = paged_attention( + query, key_cache, value_cache, input_metadata.block_tables, + input_metadata.context_lens, input_metadata.max_context_len, + self.num_kv_heads, self.scale, self.alibi_slopes, + kv_quant_param, None, input_metadata.kv_cache_dtype) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) @@ -276,23 +271,26 @@ def _make_alibi_bias( return attn_bias -def _paged_attention( +def paged_attention( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - input_metadata: InputMetadata, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + custom_bias: Optional[torch.Tensor], + kv_cache_dtype: torch.dtype, kv_quant_param: List[float], ) -> torch.Tensor: output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) # NOTE: We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -300,8 +298,8 @@ def _paged_attention( # to parallelize. # TODO: Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = max_context_len <= 8192 and (max_num_partitions == 1 + or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -311,12 +309,13 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, *kv_quant_param, ) else: @@ -343,12 +342,13 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, *kv_quant_param, ) return output diff --git a/aphrodite/modeling/layers/enc_dec_attention.py b/aphrodite/modeling/layers/enc_dec_attention.py new file mode 100644 index 0000000000..bca9a0aef3 --- /dev/null +++ b/aphrodite/modeling/layers/enc_dec_attention.py @@ -0,0 +1,240 @@ +"""Multi-head attention for encoder-decoder models.""" +from typing import Optional + +import torch +import torch.nn as nn +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, ) + +from aphrodite._C import cache_ops +from aphrodite.modeling.metadata import InputMetadata +from aphrodite.common.utils import is_hip +from aphrodite.modeling.layers.attention import paged_attention + +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] + + +class EncDecAttention(nn.Module): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. 
" + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + +class EncoderAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Encoder attention forward pass. + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + custom_bias: Custom bias tensor. + Returns: + Output tensor. + """ + # query: [batch_size, seq_len, num_heads * head_size] + # key: [batch_size, seq_len, num_heads * head_size] + # value: [batch_size, seq_len, num_heads * head_size] + # custom_bias: [batch_size, seq_len, seq_len] + # output: [batch_size, seq_len, num_heads * head_size] + + assert input_metadata.is_prompt + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.attn_bias is None: + input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + + input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] + + # Normal attention + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view(batch_size, seq_len, hidden_size) + return output + + +class DecoderAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Decoder attention forward pass. + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + custom_bias: Custom bias tensor. + Returns: + Output tensor. + """ + + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. 
+ if key_cache is not None and value_cache is not None: + + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, + input_metadata.slot_mapping[:, -1].flatten().contiguous(), + input_metadata.kv_cache_dtype) + + max_prompt_len = input_metadata.prompt_lens.max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, + prompt_table_len:].contiguous( + ) + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.context_lens, + max_context_len=input_metadata.max_context_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=input_metadata.attn_bias.to(torch.float32), + kv_cache_dtype=input_metadata.kv_cache_dtype, + kv_quant_param=input_metadata.kv_quant_params, + ) + return output.view(batch_size, seq_len, hidden_size) + + +class CrossAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Cross attention forward pass. + Args: + query: Query tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + input_metadata: Input metadata. + key: Key tensor. Only needed in the first pass. + value: Value tensor. Only needed in the first pass. + custom_bias: Custom bias tensor. + Returns: + Output tensor. + """ + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_heads, self.head_size) + + # Reshape the keys and values and store them in the cache. + # It only happens during the first pass. 
+ if (input_metadata.is_prompt and key_cache is not None + and value_cache is not None): + assert key is not None and value is not None + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping[:, :-1].flatten().contiguous(), + input_metadata.kv_cache_dtype, + ) + + max_prompt_len = input_metadata.prompt_lens.int().max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, : + prompt_table_len].contiguous( + ) + + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.prompt_lens.int(), + max_context_len=max_prompt_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=None, + kv_cache_dtype=input_metadata.kv_cache_dtype, + kv_quant_param=input_metadata.kv_quant_params, + ) + + return output.view(batch_size, seq_len, hidden_size) diff --git a/aphrodite/modeling/metadata.py b/aphrodite/modeling/metadata.py index b954d76a5b..766dc87ae6 100644 --- a/aphrodite/modeling/metadata.py +++ b/aphrodite/modeling/metadata.py @@ -49,6 +49,7 @@ def __init__( def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " + f"prompt_lens={self.prompt_lens}, " f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " diff --git a/aphrodite/modeling/models/__init__.py b/aphrodite/modeling/models/__init__.py index 091a20c334..9c22f3540c 100644 --- a/aphrodite/modeling/models/__init__.py +++ b/aphrodite/modeling/models/__init__.py @@ -44,6 +44,7 @@ "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), } # Models not supported by ROCm. diff --git a/aphrodite/modeling/models/t5.py b/aphrodite/modeling/models/t5.py new file mode 100644 index 0000000000..39509360c9 --- /dev/null +++ b/aphrodite/modeling/models/t5.py @@ -0,0 +1,603 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The PygmalionAI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
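The T5 implementation that follows relies on the cache layout set up by Sequence.__init__ and the encoder-decoder attention layers above: the encoder prompt is padded out to whole cache blocks (plus one slot for the decoder start token), and the single per-sequence block table is split into a cross-attention region and a decoder self-attention region. A minimal sketch of that split, with illustrative values that are not part of the patch:

    # Illustrative values only; mirrors the slicing in CrossAttention.forward
    # and DecoderAttention.forward above.
    block_size = 16        # cache block size
    prompt_len = 23        # encoder prompt length in tokens

    # Sequence.__init__ pads the prompt to a block boundary and reserves one
    # extra slot for the decoder start token, so the prompt always occupies
    # whole blocks.
    prompt_table_len = (prompt_len + block_size - 1) // block_size   # -> 2

    # Hypothetical block table for one sequence: the first prompt_table_len
    # entries hold the encoder-output KV (read by cross attention), the rest
    # hold decoder self-attention KV.
    block_table = [7, 3, 9, 12]
    cross_attention_blocks = block_table[:prompt_table_len]    # [7, 3]
    self_attention_blocks = block_table[prompt_table_len:]     # [9, 12]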
+"""Inference-only T5 model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import math +import copy + +import torch +from torch import nn +from transformers import T5Config + +from aphrodite.modeling.metadata import InputMetadata +from aphrodite.modeling.layers.activation import get_act_fn +from aphrodite.modeling.layers.enc_dec_attention import ( + EncoderAttention, + DecoderAttention, + CrossAttention, +) +from aphrodite.modeling.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + RowParallelLinear, +) +from aphrodite.modeling.layers.sampler import Sampler +from aphrodite.modeling.megatron.parallel_state import ( + get_tensor_model_parallel_world_size, ) +from aphrodite.modeling.sampling_metadata import SamplingMetadata +from aphrodite.modeling.hf_downloader import ( + default_weight_loader, + hf_model_weights_iterator, +) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class T5LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction + of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is + # also known as Root Mean Square Layer Normalization + # https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that + # the accumulation for half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_states, _ = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) + hidden_linear, _ = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = 
self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(nn.Module): + + def __init__( + self, + config: T5Config, + is_cross: bool, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + total_num_heads = config.num_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.n_heads = total_num_heads // tensor_model_parallel_world_size + self.inner_dim = self.n_heads * self.key_value_proj_dim + + self.q = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.k = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.v = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.o = RowParallelLinear(self.inner_dim, self.d_model, bias=False) + + if has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads) + + self.is_cross = is_cross + if self.is_decoder: + if self.is_cross: + self.attn = CrossAttention(self.n_heads, + self.key_value_proj_dim, 1) + else: + self.attn = DecoderAttention(self.n_heads, + self.key_value_proj_dim, 1) + else: + self.attn = EncoderAttention(self.n_heads, self.key_value_proj_dim, + 1) + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, + i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative + positions are invalid. We use smaller buckets for small absolute + relative_position and larger buckets for larger absolute + relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same + bucket. 
This should allow for more graceful generalization to longer + sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / + math.log(max_distance / max_exact) * + (num_buckets - max_exact)).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, + dtype=torch.long, + device="cuda")[:, None] + memory_position = torch.arange(key_length, + dtype=torch.long, + device="cuda")[None, :] + relative_position = (memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # shape (query_length, key_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + # shape (1, num_heads, query_length, key_length) + values = values.permute([2, 0, 1]).unsqueeze(0) + return values + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + q, _ = self.q(hidden_states) + + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + prompt_len = input_metadata.prompt_lens.max().item() + context_len = input_metadata.context_lens.max().item() + context_len = max(context_len, 1) + + block_size = 16 + + if not self.is_decoder: + assert kv_cache is None + # Encoder self attention, no cache operations + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + if input_metadata.attn_bias is None: + input_metadata.attn_bias = self.compute_bias( + prompt_len, (prompt_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) + for i in range(batch_size): + input_metadata.attn_bias[ + i, :, :, + input_metadata.prompt_lens[i]:, ] = torch.finfo( + input_metadata.attn_bias.dtype).min + + attn_output = self.attn(q, k, v, input_metadata) + + elif not self.is_cross: + # Decoder self attention + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + if input_metadata.attn_bias is None: + position_bias = self.compute_bias( + 1 if 
input_metadata.is_prompt else context_len, + (context_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) + input_metadata.attn_bias = position_bias[:, :, + -seq_len:, :].contiguous( + ) + + key_cache, value_cache = kv_cache + + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + + else: + # Cross attention + + key_cache, value_cache = kv_cache + if input_metadata.is_prompt: + assert encoder_hidden_states is not None + k, _ = self.k(encoder_hidden_states) + v, _ = self.v(encoder_hidden_states) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + else: + attn_output = self.attn(q, None, None, key_cache, value_cache, + input_metadata) + + attn_output, _ = self.o(attn_output) + return attn_output + + +class T5LayerSelfAttention(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.SelfAttention = T5Attention( + config, + is_cross=False, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=None, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.EncDecAttention = T5Attention( + config, + is_cross=True, + has_relative_attention_bias=False, + linear_method=linear_method, + ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + )) + if self.is_decoder: + self.layer.append( + T5LayerCrossAttention(config, linear_method=linear_method)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ): + hidden_states = self.layer[0]( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + 
torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + if self.is_decoder: + hidden_states = self.layer[1]( + hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + return hidden_states + + +class T5Stack(nn.Module): + + def __init__( + self, + config: T5Config, + embed_tokens: torch.Tensor, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.embed_tokens = embed_tokens + + self.block = nn.ModuleList([ + T5Block( + config, + has_relative_attention_bias=(i == 0), + linear_method=linear_method, + ) for i in range(config.num_layers) + ]) + + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for i, layer_module in enumerate(self.block): + kv_cache = kv_caches[i] if self.is_decoder else None + + layer_outputs = layer_module( + hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = layer_outputs + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5ForConditionalGeneration(nn.Module): + + def __init__(self, + config: T5Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + self.encoder = T5Stack(encoder_config, self.shared, linear_method) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + self.decoder = T5Stack(decoder_config, self.shared, linear_method) + + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + if input_metadata.is_prompt: + # prompt run, need to run encoder once + hidden_states = self.encoder(input_ids, kv_caches, input_metadata, + None) + # Clear the attention bias + input_metadata.attn_bias = None + batch_size = input_ids.shape[0] + input_ids = (torch.ones(batch_size, 1, dtype=torch.long) * + self.config.decoder_start_token_id).cuda() + + else: + hidden_states = None + + if kv_caches[0][0] is not None: # Skip decoder for profiling run + hidden_states = self.decoder(input_ids, kv_caches, input_metadata, + hidden_states) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + hidden_states = hidden_states * (self.model_dim**-0.5) + + return hidden_states + + def sample(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata): + next_tokens = 
self.sampler(self.shared.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "EncDecAttention.relative_attention_bias" in name: + continue + + assert name in params_dict, f"{name} not in params_dict" + param = params_dict[name] + assert param.shape == loaded_weight.shape, ( + f"{name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}") + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + diff --git a/aphrodite/task_handler/model_runner.py b/aphrodite/task_handler/model_runner.py index 26f08edde8..83f97fc681 100644 --- a/aphrodite/task_handler/model_runner.py +++ b/aphrodite/task_handler/model_runner.py @@ -92,6 +92,11 @@ def __init__( self.kv_quant_params = (self.load_kv_quant_params( model_config, kv_quant_params_path) if self.kv_cache_dtype == "int8" else None) + + # Unpack HF is_encoder_decoder config attribute + self.is_encoder_decoder = False if self.model_config is None else \ + getattr(self.model_config.hf_config, + "is_encoder_decoder", False) def load_kv_quant_params(self, model_config: ModelConfig, kv_quant_params_path: str) -> List[List[float]]: @@ -175,6 +180,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + block_tables: List[List[int]] = [] + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -248,6 +255,10 @@ def _prepare_prompt( block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) + + if self.is_encoder_decoder: + block_tables.append(block_table) + max_block_table_len = max(max_block_table_len, len(block_table)) max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad( @@ -264,9 +275,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device, ) + if self.is_encoder_decoder and len(block_tables) > 0: + # Pad the slot mapping to the same length and add decoder_start_id + for i in range(len(slot_mapping)): + slot_mapping[i] += [_PAD_SLOT_ID + ] * (max_prompt_len - len(slot_mapping[i])) + slot_mapping[i].append(block_tables[i][-1] * self.block_size) + + max_slot_mapping_len = max_prompt_len + self.is_encoder_decoder slot_mapping = _make_tensor_with_pad( slot_mapping, - max_prompt_len, + max_slot_mapping_len, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device, @@ -280,13 +299,31 @@ def _prepare_prompt( device=self.device) # Prepare prefix block tables max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = _make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a + # decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables_tensor = 
_make_tensor_with_pad( + padded_block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) + else: + # Prepare prefix block tables + max_prompt_block_table_len = max( + len(t) for t in prefix_block_tables) + block_tables_tensor = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) start_loc_tensor = torch.arange( 0, len(prompt_lens) * max_prompt_len, @@ -304,9 +341,9 @@ def _prepare_prompt( prompt_lens=prompt_lens_tensor, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=None, + max_context_len=max(context_lens), context_lens=context_lens_tensor, - block_tables=block_tables, + block_tables=block_tables_tensor, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, kv_quant_params=self.kv_quant_params, @@ -331,12 +368,14 @@ def _prepare_decode( input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + prompt_lens: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -355,11 +394,28 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - context_len = (seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window)) + prompt_len = len(seq_data.prompt_token_ids) + prompt_lens.append(prompt_len) + + if self.is_encoder_decoder: + # Encoder-decoder model stores prompt and generation tokens + # separately, so we need to adjust to the pad. + prompt_blocks_num = (prompt_len + self.block_size - + 1) // self.block_size + prompt_pad = prompt_blocks_num * self.block_size - prompt_len + position += prompt_pad + 1 # One extra for decoder_start_id + + if self.is_encoder_decoder: + context_len = seq_len - prompt_len + 1 + elif self.sliding_window is not None: + context_len = min(seq_len, self.sliding_window) + else: + context_len = seq_len context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] + max_block_table_len = max(max_block_table_len, + len(block_table)) block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset @@ -372,6 +428,16 @@ def _prepare_decode( self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a + # decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables = padded_block_tables batch_size = len(input_tokens) max_context_len = max(context_lens) @@ -379,6 +445,10 @@ def _prepare_decode( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) + + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # Pad the input tokens, positions, and slot mapping to match the # batch size of the captured graph. 
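Because the prompt is stored block-aligned with one extra slot for the decoder start token, the decode path above has to shift the write position past the padding and count only generated tokens as self-attention context. A small sketch of that bookkeeping, with illustrative numbers that are not part of the patch:

    block_size = 16
    prompt_len = 23            # encoder prompt tokens
    seq_len = 27               # prompt tokens + tokens generated so far

    # Shift the decoder's write position past the prompt padding and the
    # decoder start token, mirroring _prepare_decode above.
    prompt_blocks_num = (prompt_len + block_size - 1) // block_size   # -> 2
    prompt_pad = prompt_blocks_num * block_size - prompt_len          # -> 9
    position = (seq_len - 1) + prompt_pad + 1                         # -> 36

    # Only the generated tokens plus the decoder start token count as decoder
    # self-attention context; the prompt is served by the cross-attention
    # blocks instead.
    context_len = seq_len - prompt_len + 1                            # -> 5

    # Block tables are padded to a common width by inserting zeros before the
    # final entry, keeping the block the decoder is currently writing last.
    block_table = [7, 3, 9]        # example: 2 encoder blocks + 1 decoder block
    max_block_table_len = 5
    padded = (block_table[:-1]
              + [0] * (max_block_table_len - len(block_table))
              + block_table[-1:])
    # -> [7, 3, 0, 0, 9]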
@@ -441,7 +511,7 @@ def _prepare_decode( input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + prompt_lens=prompt_lens, max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -479,7 +549,7 @@ def _prepare_sample( sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) - if seq_group_metadata.is_prompt: + if seq_group_metadata.is_prompt and not self.is_encoder_decoder: assert len(seq_ids) == 1 assert subquery_lens is not None subquery_len = subquery_lens[i] diff --git a/kernels/attention/attention_kernels.cu b/kernels/attention/attention_kernels.cu index 00c14a21ae..3054a83576 100644 --- a/kernels/attention/attention_kernels.cu +++ b/kernels/attention/attention_kernels.cu @@ -23,6 +23,7 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -110,6 +111,7 @@ __device__ void paged_attention_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, max_seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -154,6 +156,10 @@ __device__ void paged_attention_kernel( const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float* custom_bias_vec = custom_bias == nullptr + ? nullptr + : custom_bias + seq_idx * num_kv_heads * num_context_blocks * BLOCK_SIZE + + kv_head_idx * num_context_blocks * BLOCK_SIZE; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread group @@ -246,8 +252,10 @@ __device__ void paged_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + // Add the custom or ALiBi bias if given. + qk += (custom_bias_vec != nullptr) ? custom_bias_vec[token_idx] + : (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) + : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -458,6 +466,7 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -468,7 +477,7 @@ __global__ void paged_attention_v1_kernel( paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp); + max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp); } // Grid: (num_heads, num_seqs, max_num_partitions). 
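On the Python side, the tensor that feeds the new custom_bias argument is what DecoderAttention builds from T5's relative-position bias: one contiguous float32 row per (sequence, head), with the row length rounded up to whole cache blocks so it matches the per-sequence and per-head strides assumed by the kernel's pointer arithmetic above. A hedged sketch of that layout, with illustrative shapes that are not part of the patch:

    import torch

    num_seqs, num_heads, block_size = 2, 8, 16
    context_len = 21
    padded_len = (context_len + block_size - 1) // block_size * block_size   # -> 32

    # Shape [num_seqs, num_heads, 1, padded_len], float32 and contiguous.
    # The kernel reads custom_bias_vec[token_idx] from the row for
    # (seq_idx, kv_head_idx) and adds it to the query-key dot product; for the
    # T5 layers num_kv_heads equals num_heads. Padding positions do not
    # contribute to the result, since the kernel masks tokens at or beyond
    # each sequence's context_len; they only keep every row block-aligned.
    custom_bias = torch.zeros(num_seqs, num_heads, 1, padded_len,
                              dtype=torch.float32)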
@@ -493,6 +502,7 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -502,7 +512,7 @@ __global__ void paged_attention_v2_kernel( const float v_zp) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp); } @@ -623,6 +633,7 @@ __global__ void paged_attention_v2_reduce_kernel( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride, \ @@ -649,6 +660,7 @@ void paged_attention_v1_launcher( torch::Tensor& context_lens, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const float k_scale, const float k_zp, const float v_scale, @@ -665,9 +677,10 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + const float* custom_bias_ptr = custom_bias ? reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -728,6 +741,7 @@ void paged_attention_v1_launcher( context_lens, \ max_context_len, \ alibi_slopes, \ + custom_bias, \ k_scale, \ k_zp, \ v_scale, \ @@ -763,6 +777,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, const float k_scale = 1.0f, const float k_zp = 0.0f, @@ -821,6 +836,7 @@ void paged_attention_v1( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride, \ @@ -858,6 +874,7 @@ void paged_attention_v2_launcher( torch::Tensor& context_lens, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const float k_scale, const float k_zp, const float v_scale, @@ -874,9 +891,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + const float* custom_bias_ptr = custom_bias ? 
reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -946,6 +964,7 @@ void paged_attention_v2_launcher( context_lens, \ max_context_len, \ alibi_slopes, \ + custom_bias, \ k_scale, \ k_zp, \ v_scale, \ @@ -984,6 +1003,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, const float k_scale = 1.0f, const float k_zp = 0.0f, diff --git a/kernels/ops.h b/kernels/ops.h index 634e155c33..b1769c4cf8 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -14,6 +14,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, float k_scale = 1.0f, float k_zp = 0.0f, @@ -35,6 +36,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, float k_scale = 1.0f, float k_zp = 0.0f,
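Taken together, the now-public Python wrapper and the extended kernels let a decoder layer route an arbitrary per-token bias into paged attention. A hedged sketch of such a call follows; it uses keyword arguments because the tail of the new parameter list (alibi_slopes, custom_bias, kv_cache_dtype, kv_quant_param) is easy to transpose positionally, and the helper name and its arguments are illustrative rather than part of the patch:

    import torch
    from aphrodite.modeling.layers.attention import paged_attention

    def decode_with_bias(query, key_cache, value_cache, input_metadata,
                         num_kv_heads, scale, bias):
        # bias: [num_seqs, num_heads, 1, padded_context_len], cast to float32
        # as DecoderAttention does before handing it to the kernel.
        return paged_attention(
            query=query,
            key_cache=key_cache,
            value_cache=value_cache,
            block_tables=input_metadata.block_tables,
            context_lens=input_metadata.context_lens,
            max_context_len=input_metadata.max_context_len,
            num_kv_heads=num_kv_heads,
            scale=scale,
            alibi_slopes=None,
            custom_bias=bias.to(torch.float32),
            kv_cache_dtype=input_metadata.kv_cache_dtype,
            kv_quant_param=input_metadata.kv_quant_params,
        )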
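At the user level, a T5 checkpoint can then be served through the existing entrypoint. A minimal, hedged example, assuming a CUDA device and a T5ForConditionalGeneration checkpoint (the model name is illustrative), with context shift left disabled per the new assertion in llm.py:

    from aphrodite import LLM, SamplingParams

    llm = LLM(model="google/flan-t5-base")   # any T5ForConditionalGeneration checkpoint
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(
        ["translate English to German: The house is wonderful."], params)
    print(outputs[0].outputs[0].text)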