From 55c7a22ca66ff2206976070e6e3c3a79bb965732 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Wed, 7 Feb 2024 03:43:19 +0000 Subject: [PATCH 1/8] add t5 modeling code --- aphrodite/modeling/models/t5.py | 577 ++++++++++++++++++++++++++++++++ 1 file changed, 577 insertions(+) create mode 100644 aphrodite/modeling/models/t5.py diff --git a/aphrodite/modeling/models/t5.py b/aphrodite/modeling/models/t5.py new file mode 100644 index 0000000000..35f3a52b63 --- /dev/null +++ b/aphrodite/modeling/models/t5.py @@ -0,0 +1,577 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The PygmalionAI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only T5 model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple +import math +import copy + +import torch +from torch import nn +from transformers import T5Config + +from aphrodite.modeling.metadata import InputMetadata +from aphrodite.modeling.layers.activation import get_act_fn +from aphrodite.modeling.layers.attention import PagedAttention +from aphrodite.modeling.layers.sampler import Sampler +from aphrodite.modeling.sampling_metadata import SamplingMetadata +from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator, + convert_pyslice_to_tensor) +from aphrodite.common.sequence import SamplerOutput + +from aphrodite.common.logger import init_logger + +KVCache = Tuple[torch.Tensor, torch.Tensor] + +logger = init_logger(__name__) + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states / torch.sqrt(variance + self.variance_epsilon) + + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseAct(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedAct(nn.Module): + def 
__init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedAct(config) + else: + self.DenseReluDense = T5DenseAct(config) + + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(nn.Module): + def __init__( + self, + config: T5Config, + has_relative_attention_bias: bool, + is_cross: bool): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.is_cross = is_cross + + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.has_relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads) + + self.paged_attn = PagedAttention( + self.n_heads, self.key_value_proj_dim, scale=1) + + + @staticmethod + def _relative_position_bucket( + relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, i.e. + the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are + invalid. We use smaller buckets for small absolute relative_position and + larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions + <=-max_distance map to the same bucket. This should allow for more graceful + generalization to longer sequences than the model has been trained on. 
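+
+        Example (bidirectional=True, num_buckets=32, max_distance=128, the
+        defaults used here): a relative position of -3 lands in the "exact"
+        bucket 3, while a relative position of +20 lands in bucket
+        16 + 10 = 26 on the logarithmic side.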
+ + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > + 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = - \ + torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like( + relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, + relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange( + query_length, dtype=torch.long, device="cuda")[:, None] + memory_position = torch.arange( + key_length, dtype=torch.long, device="cuda")[None, :] + relative_position = memory_position - \ + context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # shape (query_length, key_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + # shape (1, num_heads, query_length, key_length) + values = values.permute([2, 0, 1]).unsqueeze(0) + return values + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + q = self.q(hidden_states) + batch_size = hidden_states.shape[0] + + if not self.is_decoder: + # Encoder self-attn + + def shape(states): + """Projection.""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """Reshape.""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + q = shape(q) + k = shape(self.k(hidden_states)) + v = shape(self.v(hidden_states)) + + if position_bias is None: + assert self.has_relative_attention_bias + position_bias = self.compute_bias( + hidden_states.shape[1], hidden_states.shape[1]) + + scores = torch.matmul(q, k.transpose(-1, -2)) + scores = scores + position_bias + + attn_weights = nn.functional.softmax( + scores.float(), + dim=-1).type_as(scores) + attn_output = unshape(torch.matmul(attn_weights, v)) + attn_output = self.o(attn_output) + + elif not self.is_cross: + # Decoder self-attn + k = self.k(hidden_states) + v = self.v(hidden_states) + + if position_bias is None: + assert self.has_relative_attention_bias + position_bias = self.compute_bias( + input_metadata.max_context_len, 
input_metadata.max_context_len) + + key_cache, value_cache = kv_cache + + attn_output = self.paged_attn(q, k, v, key_cache, value_cache, + input_metadata, cache_event) + attn_output = self.o(attn_output) + + else: + # Decoder cross-attn + assert position_bias is None + assert self.has_relative_attention_bias == False + + k, v = kv_cache + scores = torch.matmul(q, k.transpose(-1, -2)) + attn_weights = nn.functional.softmax( + scores.float(), + dim=-1).type_as(scores) + attn_output = torch.matmul(attn_weights, v) + attn_output = self.o(attn_output) + + return (attn_output,) + (position_bias,) + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias): + super().__init__() + self.SelfAttention = T5Attention( + config, + has_relative_attention_bias=has_relative_attention_bias, + is_cross=False) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + position_bias=position_bias, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event) + hidden_states = hidden_states + attention_output[0] + return (hidden_states,) + (attention_output[1],) + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention( + config, + has_relative_attention_bias=False, + is_cross=True) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + position_bias=position_bias, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event) + hidden_states = hidden_states + attention_output[0] + return (hidden_states,) + (attention_output[1],) + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias)) + + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_cache: KVCache, + cross_attention_kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ): + + self_attention_outputs = self.layer[0]( + hidden_states=hidden_states, + position_bias=position_bias, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event) + + hidden_states = self_attention_outputs[0] + self_attention_bias = self_attention_outputs[1] + + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, 
max=clamp_value) + + if self.is_decoder: + cross_attention_outputs = self.layer[1]( + hidden_states=hidden_states, + kv_cache=cross_attention_kv_cache, + position_bias=None, + input_metadata=input_metadata, + cache_event=cache_event) + hidden_states = cross_attention_outputs[0] + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value) + + # Apply FF layer + hidden_states = self.layer[-1](hidden_states) + outputs = (hidden_states,) + (self_attention_bias,) + + return outputs + + +class T5Stack(nn.Module): + def __init__( + self, + config: T5Config, + embed_tokens: torch.Tensor): + + super().__init__() + self.is_decoder = config.is_decoder + self.embed_tokens = embed_tokens + + self.block = nn.ModuleList([ + T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers)]) + + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + kv_caches: List[KVCache], + cross_attention_kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + + hidden_states = self.embed_tokens(input_ids) + position_bias = None + + for i, layer_module in enumerate(self.block): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + + kv_cache = kv_caches[i] if self.is_decoder else None + cross_attention_kv_cache = cross_attention_kv_caches[i] if self.is_decoder else None + + layer_outputs = layer_module( + hidden_states, + position_bias=position_bias, + kv_cache=kv_cache, + cross_attention_kv_cache=cross_attention_kv_cache, + input_metadata=input_metadata, + cache_event=cache_event) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers + # The first layer to store them + # layer_outputs = hidden_states, (self-attention position bias,) + position_bias = layer_outputs[1] + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5ForConditionalGeneration(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.config = config + self.model_dim = config.d_model + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + self.sampler = Sampler(config.vocab_size) + + + # Only run decoder in the forward pass + # We need to get cross_attention_kv_cache first by + # calling model.prepare(...) 
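+    # Illustrative call order (a sketch only; `model`, the encoder/decoder
+    # input tensors and the caches are placeholders supplied by the engine):
+    #     cross_kv = model.prepare(encoder_input_ids, encoder_positions,
+    #                              input_metadata)
+    #     next_tokens = model(decoder_input_ids, decoder_positions, kv_caches,
+    #                         cross_kv, input_metadata, cache_events=None)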
+ + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + cross_attention_kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> SamplerOutput: + + decoder_outputs = self.decoder( + input_ids=input_ids, + kv_caches=kv_caches, + input_metadata=input_metadata, + cache_events=cache_events, + cross_attention_kv_caches=cross_attention_kv_caches) + + sequence_output = decoder_outputs + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + sequence_output = sequence_output * (self.model_dim**-0.5) + + next_tokens = self.sampler( + self.shared.weight, sequence_output, positions) + + return next_tokens + + def prepare( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata) -> torch.Tensor: + + encoder_outputs = self.encoder( + input_ids=input_ids, + kv_caches=None, + input_metadata=input_metadata, + cache_events=None, + cross_attention_kv_caches=None) + + cross_attention_kv_caches: List[torch.Tensor] = [] + + for block in self.decoder.block: + cross_attention_layer = block.layer[1] + k = cross_attention_layer.EncDecAttention.k(encoder_outputs) + v = cross_attention_layer.EncDecAttention.v(encoder_outputs) + cross_attention_kv_caches.append(torch.stack([k, v], dim=1)) + + cross_attention_kv_caches_tensor = torch.stack( + cross_attention_kv_caches, dim=0).transpose(0, 1) + + return cross_attention_kv_caches_tensor + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if 'EncDecAttention.relative_attention_bias' in name: + continue + assert name in state_dict + + loaded_weight = convert_pyslice_to_tensor(loaded_weight) + param = state_dict[name] + assert param.shape == loaded_weight.shape, ( + f"{name} shape mismatch between model and checkpoint: " + f"{param.shape} vs {loaded_weight.shape}") + + param.data.copy_(loaded_weight) From f009f94ffd538e45248668b88997250e112e4ebd Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 20:34:58 +0000 Subject: [PATCH 2/8] update modeling code Co-authored-by: Jin Shang --- aphrodite/modeling/models/t5.py | 646 +++++++++++++++++--------------- 1 file changed, 336 insertions(+), 310 deletions(-) diff --git a/aphrodite/modeling/models/t5.py b/aphrodite/modeling/models/t5.py index 35f3a52b63..39509360c9 100644 --- a/aphrodite/modeling/models/t5.py +++ b/aphrodite/modeling/models/t5.py @@ -22,7 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only T5 model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple + import math import copy @@ -32,184 +33,223 @@ from aphrodite.modeling.metadata import InputMetadata from aphrodite.modeling.layers.activation import get_act_fn -from aphrodite.modeling.layers.attention import PagedAttention +from aphrodite.modeling.layers.enc_dec_attention import ( + EncoderAttention, + DecoderAttention, + CrossAttention, +) +from aphrodite.modeling.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + RowParallelLinear, +) from aphrodite.modeling.layers.sampler import Sampler +from aphrodite.modeling.megatron.parallel_state import ( + get_tensor_model_parallel_world_size, ) from aphrodite.modeling.sampling_metadata import SamplingMetadata -from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator, - convert_pyslice_to_tensor) -from aphrodite.common.sequence import SamplerOutput - -from aphrodite.common.logger import init_logger +from aphrodite.modeling.hf_downloader import ( + default_weight_loader, + hf_model_weights_iterator, +) KVCache = Tuple[torch.Tensor, torch.Tensor] -logger = init_logger(__name__) class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction + of mean. + """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states / torch.sqrt(variance + self.variance_epsilon) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is + # also known as Root Mean Square Layer Normalization + # https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that + # the accumulation for half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states -class T5DenseAct(nn.Module): +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) self.act = get_act_fn(config.dense_act_fn) def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) + hidden_states, _ = self.wi(hidden_states) hidden_states = self.act(hidden_states) - hidden_states = self.wo(hidden_states) + hidden_states, _ = self.wo(hidden_states) return hidden_states - -class T5DenseGatedAct(nn.Module): + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) self.act = get_act_fn(config.dense_act_fn) def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) + hidden_linear, _ = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear - hidden_states = self.wo(hidden_states) + hidden_states, _ = self.wo(hidden_states) return hidden_states - + class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): super().__init__() if config.is_gated_act: - self.DenseReluDense = T5DenseGatedAct(config) + self.DenseReluDense = T5DenseGatedActDense(config) else: - self.DenseReluDense = T5DenseAct(config) + self.DenseReluDense = T5DenseActDense(config) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - + def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) hidden_states = hidden_states + forwarded_states return hidden_states - + class T5Attention(nn.Module): + def __init__( - self, - config: T5Config, - has_relative_attention_bias: bool, - is_cross: bool): + self, + config: T5Config, + is_cross: bool, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = config.relative_attention_num_buckets self.relative_attention_max_distance = config.relative_attention_max_distance self.d_model = config.d_model self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads + total_num_heads = config.num_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + assert total_num_heads % 
tensor_model_parallel_world_size == 0 + self.n_heads = total_num_heads // tensor_model_parallel_world_size self.inner_dim = self.n_heads * self.key_value_proj_dim - self.is_cross = is_cross - self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) - self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + self.q = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.k = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.v = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.o = RowParallelLinear(self.inner_dim, self.d_model, bias=False) - if self.has_relative_attention_bias: - self.has_relative_attention_bias = nn.Embedding( + if has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( self.relative_attention_num_buckets, self.n_heads) - - self.paged_attn = PagedAttention( - self.n_heads, self.key_value_proj_dim, scale=1) - - + + self.is_cross = is_cross + if self.is_decoder: + if self.is_cross: + self.attn = CrossAttention(self.n_heads, + self.key_value_proj_dim, 1) + else: + self.attn = DecoderAttention(self.n_heads, + self.key_value_proj_dim, 1) + else: + self.attn = EncoderAttention(self.n_heads, self.key_value_proj_dim, + 1) + @staticmethod - def _relative_position_bucket( - relative_position, - bidirectional=True, - num_buckets=32, - max_distance=128): - """Adapted from Mesh Tensorflow: + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """ + Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. - The relative position is defined as memory_position - query_position, i.e. - the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are - invalid. We use smaller buckets for small absolute relative_position and - larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions - <=-max_distance map to the same bucket. This should allow for more graceful - generalization to longer sequences than the model has been trained on. - + The relative position is defined as memory_position - query_position, + i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative + positions are invalid. We use smaller buckets for small absolute + relative_position and larger buckets for larger absolute + relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same + bucket. 
This should allow for more graceful generalization to longer + sequences than the model has been trained on Args: relative_position: an int32 Tensor bidirectional: a boolean - whether the attention is bidirectional num_buckets: an integer max_distance: an integer - Returns: a Tensor with the same shape as relative_position, containing int32 - values in the range [0, num_buckets) + values in the range [0, num_buckets) """ relative_buckets = 0 if bidirectional: num_buckets //= 2 - relative_buckets += (relative_position > - 0).to(torch.long) * num_buckets + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = - \ - torch.min(relative_position, - torch.zeros_like(relative_position)) + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions max_exact = num_buckets // 2 is_small = relative_position < max_exact - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) + torch.log(relative_position.float() / max_exact) / + math.log(max_distance / max_exact) * + (num_buckets - max_exact)).to(torch.long) relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like( - relative_position_if_large, num_buckets - 1) + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), ) - relative_buckets += torch.where(is_small, - relative_position, relative_position_if_large) + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) return relative_buckets def compute_bias(self, query_length, key_length): """Compute binned relative position bias""" - context_position = torch.arange( - query_length, dtype=torch.long, device="cuda")[:, None] - memory_position = torch.arange( - key_length, dtype=torch.long, device="cuda")[None, :] - relative_position = memory_position - \ - context_position # shape (query_length, key_length) + context_position = torch.arange(query_length, + dtype=torch.long, + device="cuda")[:, None] + memory_position = torch.arange(key_length, + dtype=torch.long, + device="cuda")[None, :] + relative_position = (memory_position - context_position + ) # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), @@ -223,168 +263,182 @@ def compute_bias(self, query_length, key_length): return values def forward( - self, - hidden_states: torch.Tensor, - position_bias: Optional[torch.Tensor], - kv_cache: KVCache, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: - q = self.q(hidden_states) + q, _ = self.q(hidden_states) + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + prompt_len = input_metadata.prompt_lens.max().item() + context_len = input_metadata.context_lens.max().item() + context_len = max(context_len, 
1) - if not self.is_decoder: - # Encoder self-attn - - def shape(states): - """Projection.""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """Reshape.""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - q = shape(q) - k = shape(self.k(hidden_states)) - v = shape(self.v(hidden_states)) - - if position_bias is None: - assert self.has_relative_attention_bias - position_bias = self.compute_bias( - hidden_states.shape[1], hidden_states.shape[1]) - - scores = torch.matmul(q, k.transpose(-1, -2)) - scores = scores + position_bias + block_size = 16 - attn_weights = nn.functional.softmax( - scores.float(), - dim=-1).type_as(scores) - attn_output = unshape(torch.matmul(attn_weights, v)) - attn_output = self.o(attn_output) + if not self.is_decoder: + assert kv_cache is None + # Encoder self attention, no cache operations + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + if input_metadata.attn_bias is None: + input_metadata.attn_bias = self.compute_bias( + prompt_len, (prompt_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) + for i in range(batch_size): + input_metadata.attn_bias[ + i, :, :, + input_metadata.prompt_lens[i]:, ] = torch.finfo( + input_metadata.attn_bias.dtype).min + + attn_output = self.attn(q, k, v, input_metadata) elif not self.is_cross: - # Decoder self-attn - k = self.k(hidden_states) - v = self.v(hidden_states) + # Decoder self attention + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) - if position_bias is None: - assert self.has_relative_attention_bias + if input_metadata.attn_bias is None: position_bias = self.compute_bias( - input_metadata.max_context_len, input_metadata.max_context_len) - + 1 if input_metadata.is_prompt else context_len, + (context_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) + input_metadata.attn_bias = position_bias[:, :, + -seq_len:, :].contiguous( + ) + key_cache, value_cache = kv_cache - attn_output = self.paged_attn(q, k, v, key_cache, value_cache, - input_metadata, cache_event) - attn_output = self.o(attn_output) - + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + else: - # Decoder cross-attn - assert position_bias is None - assert self.has_relative_attention_bias == False - - k, v = kv_cache - scores = torch.matmul(q, k.transpose(-1, -2)) - attn_weights = nn.functional.softmax( - scores.float(), - dim=-1).type_as(scores) - attn_output = torch.matmul(attn_weights, v) - attn_output = self.o(attn_output) - - return (attn_output,) + (position_bias,) - + # Cross attention + + key_cache, value_cache = kv_cache + if input_metadata.is_prompt: + assert encoder_hidden_states is not None + k, _ = self.k(encoder_hidden_states) + v, _ = self.v(encoder_hidden_states) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + else: + attn_output = self.attn(q, None, None, key_cache, value_cache, + input_metadata) + + attn_output, _ = self.o(attn_output) + return attn_output + + class T5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias): + + def __init__( + self, + config, + has_relative_attention_bias, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.SelfAttention = T5Attention( config, + is_cross=False, has_relative_attention_bias=has_relative_attention_bias, - is_cross=False) + linear_method=linear_method, + ) self.layer_norm = 
T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - + def forward( - self, - hidden_states: torch.Tensor, - position_bias: Optional[torch.Tensor], - kv_cache: KVCache, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, ) -> torch.Tensor: - normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( hidden_states=normed_hidden_states, - position_bias=position_bias, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event) - hidden_states = hidden_states + attention_output[0] - return (hidden_states,) + (attention_output[1],) + encoder_hidden_states=None, + ) + hidden_states = hidden_states + attention_output + return hidden_states class T5LayerCrossAttention(nn.Module): - def __init__(self, config): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.EncDecAttention = T5Attention( config, + is_cross=True, has_relative_attention_bias=False, - is_cross=True) + linear_method=linear_method, + ) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - + def forward( self, hidden_states: torch.Tensor, - position_bias: Optional[torch.Tensor], - kv_cache: KVCache, + kv_cache: Optional[KVCache], input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], + encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: - normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( hidden_states=normed_hidden_states, - position_bias=position_bias, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event) - hidden_states = hidden_states + attention_output[0] - return (hidden_states,) + (attention_output[1],) - + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = hidden_states + attention_output + return hidden_states + class T5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + + def __init__( + self, + config, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(T5LayerSelfAttention( - config, - has_relative_attention_bias=has_relative_attention_bias)) - + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + )) if self.is_decoder: - self.layer.append(T5LayerCrossAttention(config)) - self.layer.append(T5LayerFF(config)) + self.layer.append( + T5LayerCrossAttention(config, linear_method=linear_method)) + + self.layer.append(T5LayerFF(config)) def forward( self, hidden_states: torch.Tensor, - position_bias: Optional[torch.Tensor], - kv_cache: KVCache, - cross_attention_kv_cache: KVCache, + kv_cache: Optional[KVCache], input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], + encoder_hidden_states: Optional[torch.Tensor], ): - - self_attention_outputs = self.layer[0]( + hidden_states = self.layer[0]( hidden_states=hidden_states, - position_bias=position_bias, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event) - - hidden_states = self_attention_outputs[0] - self_attention_bias = self_attention_outputs[1] + ) if hidden_states.dtype == torch.float16: clamp_value = torch.where( @@ -392,186 +446,158 @@ def forward( torch.finfo(hidden_states.dtype).max - 
1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value) - + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + if self.is_decoder: - cross_attention_outputs = self.layer[1]( - hidden_states=hidden_states, - kv_cache=cross_attention_kv_cache, - position_bias=None, + hidden_states = self.layer[1]( + hidden_states, + kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event) - hidden_states = cross_attention_outputs[0] + encoder_hidden_states=encoder_hidden_states, + ) if hidden_states.dtype == torch.float16: clamp_value = torch.where( torch.isinf(hidden_states).any(), torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value) - - # Apply FF layer + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) - outputs = (hidden_states,) + (self_attention_bias,) - return outputs - + return hidden_states + class T5Stack(nn.Module): + def __init__( - self, - config: T5Config, - embed_tokens: torch.Tensor): - + self, + config: T5Config, + embed_tokens: torch.Tensor, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.embed_tokens = embed_tokens self.block = nn.ModuleList([ - T5Block(config, has_relative_attention_bias=bool(i == 0)) - for i in range(config.num_layers)]) - + T5Block( + config, + has_relative_attention_bias=(i == 0), + linear_method=linear_method, + ) for i in range(config.num_layers) + ]) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - + def forward( - self, - input_ids: torch.Tensor, - kv_caches: List[KVCache], - cross_attention_kv_caches: List[KVCache], - input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], + self, + input_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - position_bias = None for i, layer_module in enumerate(self.block): - if cache_events is None: - cache_event = None - else: - cache_event = cache_events[i] - kv_cache = kv_caches[i] if self.is_decoder else None - cross_attention_kv_cache = cross_attention_kv_caches[i] if self.is_decoder else None layer_outputs = layer_module( hidden_states, - position_bias=position_bias, kv_cache=kv_cache, - cross_attention_kv_cache=cross_attention_kv_cache, input_metadata=input_metadata, - cache_event=cache_event) - - hidden_states = layer_outputs[0] + encoder_hidden_states=encoder_hidden_states, + ) - # We share the position biases between the layers - # The first layer to store them - # layer_outputs = hidden_states, (self-attention position bias,) - position_bias = layer_outputs[1] + hidden_states = layer_outputs hidden_states = self.final_layer_norm(hidden_states) return hidden_states - + class T5ForConditionalGeneration(nn.Module): - def __init__(self, config: T5Config): + + def __init__(self, + config: T5Config, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config self.model_dim = config.d_model + self.shared = nn.Embedding(config.vocab_size, config.d_model) - + encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False - encoder_config.use_cache = False - 
encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = T5Stack(encoder_config, self.shared, linear_method) + decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) - self.sampler = Sampler(config.vocab_size) + self.decoder = T5Stack(decoder_config, self.shared, linear_method) + self.sampler = Sampler(config.vocab_size) - # Only run decoder in the forward pass - # We need to get cross_attention_kv_cache first by - # calling model.prepare(...) - def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], - cross_attention_kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: - - decoder_outputs = self.decoder( - input_ids=input_ids, - kv_caches=kv_caches, - input_metadata=input_metadata, - cache_events=cache_events, - cross_attention_kv_caches=cross_attention_kv_caches) - - sequence_output = decoder_outputs + ) -> torch.Tensor: + if input_metadata.is_prompt: + # prompt run, need to run encoder once + hidden_states = self.encoder(input_ids, kv_caches, input_metadata, + None) + # Clear the attention bias + input_metadata.attn_bias = None + batch_size = input_ids.shape[0] + input_ids = (torch.ones(batch_size, 1, dtype=torch.long) * + self.config.decoder_start_token_id).cuda() + + else: + hidden_states = None + + if kv_caches[0][0] is not None: # Skip decoder for profiling run + hidden_states = self.decoder(input_ids, kv_caches, input_metadata, + hidden_states) if self.config.tie_word_embeddings: # Rescale output before projecting on vocab - sequence_output = sequence_output * (self.model_dim**-0.5) + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + hidden_states = hidden_states * (self.model_dim**-0.5) - next_tokens = self.sampler( - self.shared.weight, sequence_output, positions) - - return next_tokens + return hidden_states - def prepare( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_metadata: InputMetadata) -> torch.Tensor: - - encoder_outputs = self.encoder( - input_ids=input_ids, - kv_caches=None, - input_metadata=input_metadata, - cache_events=None, - cross_attention_kv_caches=None) - - cross_attention_kv_caches: List[torch.Tensor] = [] - - for block in self.decoder.block: - cross_attention_layer = block.layer[1] - k = cross_attention_layer.EncDecAttention.k(encoder_outputs) - v = cross_attention_layer.EncDecAttention.v(encoder_outputs) - cross_attention_kv_caches.append(torch.stack([k, v], dim=1)) - - cross_attention_kv_caches_tensor = torch.stack( - cross_attention_kv_caches, dim=0).transpose(0, 1) - - return cross_attention_kv_caches_tensor - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - - state_dict = self.state_dict() + def sample(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata): + next_tokens = self.sampler(self.shared.weight, hidden_states, + sampling_metadata) + return next_tokens + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) for 
name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): - if 'EncDecAttention.relative_attention_bias' in name: + model_name_or_path, cache_dir, load_format, revision): + if "EncDecAttention.relative_attention_bias" in name: continue - assert name in state_dict - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - param = state_dict[name] + assert name in params_dict, f"{name} not in params_dict" + param = params_dict[name] assert param.shape == loaded_weight.shape, ( f"{name} shape mismatch between model and checkpoint: " - f"{param.shape} vs {loaded_weight.shape}") - - param.data.copy_(loaded_weight) + f"{param.shape} != {loaded_weight.shape}") + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + From a788ca33bf452625bca2e127ce64c2c777e370ce Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 20:46:00 +0000 Subject: [PATCH 3/8] hack in custom bias for attention kernels --- aphrodite/modeling/models/__init__.py | 1 + kernels/attention/attention_kernels.cu | 40 +++++++++++++++++++------- kernels/ops.h | 2 ++ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/aphrodite/modeling/models/__init__.py b/aphrodite/modeling/models/__init__.py index 091a20c334..9c22f3540c 100644 --- a/aphrodite/modeling/models/__init__.py +++ b/aphrodite/modeling/models/__init__.py @@ -44,6 +44,7 @@ "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), } # Models not supported by ROCm. diff --git a/kernels/attention/attention_kernels.cu b/kernels/attention/attention_kernels.cu index 00c14a21ae..3054a83576 100644 --- a/kernels/attention/attention_kernels.cu +++ b/kernels/attention/attention_kernels.cu @@ -23,6 +23,7 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -110,6 +111,7 @@ __device__ void paged_attention_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, max_seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -154,6 +156,10 @@ __device__ void paged_attention_kernel( const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float* custom_bias_vec = custom_bias == nullptr + ? nullptr + : custom_bias + seq_idx * num_kv_heads * num_context_blocks * BLOCK_SIZE + + kv_head_idx * num_context_blocks * BLOCK_SIZE; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread group @@ -246,8 +252,10 @@ __device__ void paged_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + // Add the custom or ALiBi bias if given. + qk += (custom_bias_vec != nullptr) ? custom_bias_vec[token_idx] + : (alibi_slope != 0) ? 
alibi_slope * (token_idx - context_len + 1) + : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -458,6 +466,7 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -468,7 +477,7 @@ __global__ void paged_attention_v1_kernel( paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp); + max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp); } // Grid: (num_heads, num_seqs, max_num_partitions). @@ -493,6 +502,7 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -502,7 +512,7 @@ __global__ void paged_attention_v2_kernel( const float v_zp) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp); } @@ -623,6 +633,7 @@ __global__ void paged_attention_v2_reduce_kernel( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride, \ @@ -649,6 +660,7 @@ void paged_attention_v1_launcher( torch::Tensor& context_lens, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const float k_scale, const float k_zp, const float v_scale, @@ -665,9 +677,10 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + const float* custom_bias_ptr = custom_bias ? 
reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -728,6 +741,7 @@ void paged_attention_v1_launcher( context_lens, \ max_context_len, \ alibi_slopes, \ + custom_bias, \ k_scale, \ k_zp, \ v_scale, \ @@ -763,6 +777,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, const float k_scale = 1.0f, const float k_zp = 0.0f, @@ -821,6 +836,7 @@ void paged_attention_v1( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride, \ @@ -858,6 +874,7 @@ void paged_attention_v2_launcher( torch::Tensor& context_lens, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const float k_scale, const float k_zp, const float v_scale, @@ -874,9 +891,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + const float* custom_bias_ptr = custom_bias ? reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -946,6 +964,7 @@ void paged_attention_v2_launcher( context_lens, \ max_context_len, \ alibi_slopes, \ + custom_bias, \ k_scale, \ k_zp, \ v_scale, \ @@ -984,6 +1003,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, const float k_scale = 1.0f, const float k_zp = 0.0f, diff --git a/kernels/ops.h b/kernels/ops.h index 634e155c33..b1769c4cf8 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -14,6 +14,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, float k_scale = 1.0f, float k_zp = 0.0f, @@ -35,6 +36,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype, float k_scale = 1.0f, float k_zp = 0.0f, From 58e89e29d97dd84afd479f950703b0880217e376 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 21:02:51 +0000 Subject: [PATCH 4/8] add custom bias to attention.py --- aphrodite/modeling/layers/attention.py | 52 +++++++++++++------------- aphrodite/modeling/metadata.py | 1 + 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/aphrodite/modeling/layers/attention.py b/aphrodite/modeling/layers/attention.py index 89e981d892..d7d89b401e 100644 --- a/aphrodite/modeling/layers/attention.py +++ b/aphrodite/modeling/layers/attention.py @@ -226,17 +226,12 @@ def forward( ) else: - # Decoding run. 
- output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_quant_param, - ) + # Decoding run + output = paged_attention( + query, key_cache, value_cache, input_metadata.block_tables, + input_metadata.context_lens, input_metadata.max_context_len, + self.num_kv_heads, self.scale, self.alibi_slopes, + kv_quant_param, None, input_metadata.kv_cache_dtype) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) @@ -276,23 +271,26 @@ def _make_alibi_bias( return attn_bias -def _paged_attention( +def paged_attention( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - input_metadata: InputMetadata, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + custom_bias: Optional[torch.Tensor], + kv_cache_dtype: torch.dtype, kv_quant_param: List[float], ) -> torch.Tensor: output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) # NOTE: We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -300,8 +298,8 @@ def _paged_attention( # to parallelize. # TODO: Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = max_context_len <= 8192 and (max_num_partitions == 1 + or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. 
ops.paged_attention_v1( @@ -311,12 +309,13 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, *kv_quant_param, ) else: @@ -343,12 +342,13 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, *kv_quant_param, ) return output diff --git a/aphrodite/modeling/metadata.py b/aphrodite/modeling/metadata.py index b954d76a5b..766dc87ae6 100644 --- a/aphrodite/modeling/metadata.py +++ b/aphrodite/modeling/metadata.py @@ -49,6 +49,7 @@ def __init__( def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " + f"prompt_lens={self.prompt_lens}, " f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " From 3ed4cc431cd0b3b43eecd4a199d36983f99ae822 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 21:06:10 +0000 Subject: [PATCH 5/8] enc_dec attention code --- .../modeling/layers/enc_dec_attention.py | 240 ++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 aphrodite/modeling/layers/enc_dec_attention.py diff --git a/aphrodite/modeling/layers/enc_dec_attention.py b/aphrodite/modeling/layers/enc_dec_attention.py new file mode 100644 index 0000000000..bca9a0aef3 --- /dev/null +++ b/aphrodite/modeling/layers/enc_dec_attention.py @@ -0,0 +1,240 @@ +"""Multi-head attention for encoder-decoder models.""" +from typing import Optional + +import torch +import torch.nn as nn +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, ) + +from aphrodite._C import cache_ops +from aphrodite.modeling.metadata import InputMetadata +from aphrodite.common.utils import is_hip +from aphrodite.modeling.layers.attention import paged_attention + +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] + + +class EncDecAttention(nn.Module): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + +class EncoderAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Encoder attention forward pass. + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + custom_bias: Custom bias tensor. + Returns: + Output tensor. 
+ """ + # query: [batch_size, seq_len, num_heads * head_size] + # key: [batch_size, seq_len, num_heads * head_size] + # value: [batch_size, seq_len, num_heads * head_size] + # custom_bias: [batch_size, seq_len, seq_len] + # output: [batch_size, seq_len, num_heads * head_size] + + assert input_metadata.is_prompt + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.attn_bias is None: + input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + + input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] + + # Normal attention + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view(batch_size, seq_len, hidden_size) + return output + + +class DecoderAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Decoder attention forward pass. + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + custom_bias: Custom bias tensor. + Returns: + Output tensor. + """ + + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. 
+ if key_cache is not None and value_cache is not None: + + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, + input_metadata.slot_mapping[:, -1].flatten().contiguous(), + input_metadata.kv_cache_dtype) + + max_prompt_len = input_metadata.prompt_lens.max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, + prompt_table_len:].contiguous( + ) + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.context_lens, + max_context_len=input_metadata.max_context_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=input_metadata.attn_bias.to(torch.float32), + kv_cache_dtype=input_metadata.kv_cache_dtype, + kv_quant_param=input_metadata.kv_quant_params, + ) + return output.view(batch_size, seq_len, hidden_size) + + +class CrossAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Cross attention forward pass. + Args: + query: Query tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + input_metadata: Input metadata. + key: Key tensor. Only needed in the first pass. + value: Value tensor. Only needed in the first pass. + custom_bias: Custom bias tensor. + Returns: + Output tensor. + """ + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_heads, self.head_size) + + # Reshape the keys and values and store them in the cache. + # It only happens during the first pass. 
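As a companion sketch (again with invented numbers): the per-sequence block table stores the encoder-prompt blocks first and the decoder blocks after them, split at `prompt_table_len`. The decoder self-attention above keeps the tail of the table, and the cross-attention below keeps the head.

```python
import torch

block_size = 16
prompt_lens = torch.tensor([10, 30])                 # two sequences, made-up lengths
max_prompt_len = int(prompt_lens.max().item())       # 30
prompt_table_len = (max_prompt_len + block_size - 1) // block_size   # 2

# Hypothetical padded block tables: prompt blocks (padded with 0) first,
# then the block(s) that hold generated decoder tokens.
block_tables = torch.tensor([[3, 0, 9],
                             [5, 6, 11]], dtype=torch.int)

cross_attn_tables = block_tables[:, :prompt_table_len].contiguous()  # prompt KV
self_attn_tables = block_tables[:, prompt_table_len:].contiguous()   # decoder KV
print(cross_attn_tables.tolist(), self_attn_tables.tolist())
```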
+ if (input_metadata.is_prompt and key_cache is not None + and value_cache is not None): + assert key is not None and value is not None + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping[:, :-1].flatten().contiguous(), + input_metadata.kv_cache_dtype, + ) + + max_prompt_len = input_metadata.prompt_lens.int().max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, : + prompt_table_len].contiguous( + ) + + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.prompt_lens.int(), + max_context_len=max_prompt_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=None, + kv_cache_dtype=input_metadata.kv_cache_dtype, + kv_quant_param=input_metadata.kv_quant_params, + ) + + return output.view(batch_size, seq_len, hidden_size) From 72659e5cad4625bfd2fe14764764de8118fd2956 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 21:08:15 +0000 Subject: [PATCH 6/8] separate prompt and genned tokens for enc-dec --- aphrodite/common/sequence.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/aphrodite/common/sequence.py b/aphrodite/common/sequence.py index 60e0116d86..a497b5f639 100644 --- a/aphrodite/common/sequence.py +++ b/aphrodite/common/sequence.py @@ -142,6 +142,7 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + is_encoder_decoder: bool, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id @@ -154,8 +155,20 @@ def __init__( self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] + initial_token_ids = prompt_token_ids + if is_encoder_decoder: + # We need to separate the prompt and generated tokens for + # encoder-decoder models. + num_prompt_blocks = (len(prompt_token_ids) + block_size - + 1) // block_size + padded_prompt_len = num_prompt_blocks * block_size + initial_token_ids = prompt_token_ids + [0] * ( + padded_prompt_len - len(prompt_token_ids)) + # Also need to append decoder_start_token_id + initial_token_ids.append(0) + # Initialize the logical token blocks with the prompt token ids. 
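To make the padding above concrete (block size and prompt length are invented): the encoder prompt is rounded up to a whole number of blocks and one extra position is reserved for the decoder start token, so prompt and generated tokens never share a block.

```python
block_size = 16
prompt_token_ids = list(range(1000, 1010))   # 10 hypothetical prompt token ids

num_prompt_blocks = (len(prompt_token_ids) + block_size - 1) // block_size   # 1
padded_prompt_len = num_prompt_blocks * block_size                           # 16

initial_token_ids = prompt_token_ids + [0] * (padded_prompt_len -
                                              len(prompt_token_ids))
initial_token_ids.append(0)   # placeholder slot for decoder_start_token_id

assert len(initial_token_ids) == 17
# -> two logical blocks: block 0 holds the encoder prompt, block 1 the decoder side.
```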
- self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(initial_token_ids) self.status = SequenceStatus.WAITING # Used for incremental detokenization From f9726a364967329fc28d29f787b0a9097b35e336 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 21:23:09 +0000 Subject: [PATCH 7/8] hell --- aphrodite/task_handler/model_runner.py | 98 ++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 14 deletions(-) diff --git a/aphrodite/task_handler/model_runner.py b/aphrodite/task_handler/model_runner.py index 26f08edde8..83f97fc681 100644 --- a/aphrodite/task_handler/model_runner.py +++ b/aphrodite/task_handler/model_runner.py @@ -92,6 +92,11 @@ def __init__( self.kv_quant_params = (self.load_kv_quant_params( model_config, kv_quant_params_path) if self.kv_cache_dtype == "int8" else None) + + # Unpack HF is_encoder_decoder config attribute + self.is_encoder_decoder = False if self.model_config is None else \ + getattr(self.model_config.hf_config, + "is_encoder_decoder", False) def load_kv_quant_params(self, model_config: ModelConfig, kv_quant_params_path: str) -> List[List[float]]: @@ -175,6 +180,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + block_tables: List[List[int]] = [] + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -248,6 +255,10 @@ def _prepare_prompt( block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) + + if self.is_encoder_decoder: + block_tables.append(block_table) + max_block_table_len = max(max_block_table_len, len(block_table)) max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad( @@ -264,9 +275,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device, ) + if self.is_encoder_decoder and len(block_tables) > 0: + # Pad the slot mapping to the same length and add decoder_start_id + for i in range(len(slot_mapping)): + slot_mapping[i] += [_PAD_SLOT_ID + ] * (max_prompt_len - len(slot_mapping[i])) + slot_mapping[i].append(block_tables[i][-1] * self.block_size) + + max_slot_mapping_len = max_prompt_len + self.is_encoder_decoder slot_mapping = _make_tensor_with_pad( slot_mapping, - max_prompt_len, + max_slot_mapping_len, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device, @@ -280,13 +299,31 @@ def _prepare_prompt( device=self.device) # Prepare prefix block tables max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = _make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a + # decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables_tensor = _make_tensor_with_pad( + padded_block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) + else: + # Prepare prefix block tables + max_prompt_block_table_len = max( + len(t) for t in prefix_block_tables) + block_tables_tensor = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) start_loc_tensor = 
torch.arange( 0, len(prompt_lens) * max_prompt_len, @@ -304,9 +341,9 @@ def _prepare_prompt( prompt_lens=prompt_lens_tensor, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=None, + max_context_len=max(context_lens), context_lens=context_lens_tensor, - block_tables=block_tables, + block_tables=block_tables_tensor, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, kv_quant_params=self.kv_quant_params, @@ -331,12 +368,14 @@ def _prepare_decode( input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + prompt_lens: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -355,11 +394,28 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - context_len = (seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window)) + prompt_len = len(seq_data.prompt_token_ids) + prompt_lens.append(prompt_len) + + if self.is_encoder_decoder: + # Encoder-decoder model stores prompt and generation tokens + # separately, so we need to adjust to the pad. + prompt_blocks_num = (prompt_len + self.block_size - + 1) // self.block_size + prompt_pad = prompt_blocks_num * self.block_size - prompt_len + position += prompt_pad + 1 # One extra for decoder_start_id + + if self.is_encoder_decoder: + context_len = seq_len - prompt_len + 1 + elif self.sliding_window is not None: + context_len = min(seq_len, self.sliding_window) + else: + context_len = seq_len context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] + max_block_table_len = max(max_block_table_len, + len(block_table)) block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset @@ -372,6 +428,16 @@ def _prepare_decode( self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a + # decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables = padded_block_tables batch_size = len(input_tokens) max_context_len = max(context_lens) @@ -379,6 +445,10 @@ def _prepare_decode( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) + + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # Pad the input tokens, positions, and slot mapping to match the # batch size of the captured graph. 
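Tying the decode-path bookkeeping together, here is a sketch with illustrative numbers that mirrors the formulas above: generated tokens start after the padded prompt plus the reserved decoder-start slot, and the context length passed to paged attention counts only the decoder-side tokens.

```python
block_size = 16
prompt_len = 10        # unpadded encoder prompt length
num_generated = 4      # decoder tokens appended to the sequence so far
seq_len = prompt_len + num_generated

position = seq_len - 1                                           # 13
prompt_blocks_num = (prompt_len + block_size - 1) // block_size  # 1
prompt_pad = prompt_blocks_num * block_size - prompt_len         # 6
position += prompt_pad + 1   # +1 for the decoder_start_id slot -> 20

context_len = seq_len - prompt_len + 1   # decoder tokens incl. start token -> 5

block_index = position // block_size     # 1 -> the decoder block in the table
block_offset = position % block_size     # 4
print(position, context_len, block_index, block_offset)
```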
@@ -441,7 +511,7 @@ def _prepare_decode( input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + prompt_lens=prompt_lens, max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -479,7 +549,7 @@ def _prepare_sample( sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) - if seq_group_metadata.is_prompt: + if seq_group_metadata.is_prompt and not self.is_encoder_decoder: assert len(seq_ids) == 1 assert subquery_lens is not None subquery_len = subquery_lens[i] From c3b15f0926dd24885d9f780f01a451d26fd52653 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 23 Mar 2024 21:27:07 +0000 Subject: [PATCH 8/8] do not allow context shift for enc-dec --- aphrodite/endpoints/llm.py | 4 ++++ aphrodite/engine/aphrodite_engine.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/aphrodite/endpoints/llm.py b/aphrodite/endpoints/llm.py index 18cdb3baa4..b48444666b 100644 --- a/aphrodite/endpoints/llm.py +++ b/aphrodite/endpoints/llm.py @@ -156,6 +156,10 @@ def generate( if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() + + if self.llm_engine.is_encoder_decoder: + assert (self.llm_engine.cache_config.context_shift is None + ), "Encoder-decoder models do not support context shift." # Add requests to the engine. num_requests = len(prompts) if prompts is not None else len( diff --git a/aphrodite/engine/aphrodite_engine.py b/aphrodite/engine/aphrodite_engine.py index b811bfeecc..fc3041dc09 100644 --- a/aphrodite/engine/aphrodite_engine.py +++ b/aphrodite/engine/aphrodite_engine.py @@ -125,6 +125,9 @@ def __init__( self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, labels=dict(model_name=model_config.model)) + + self.is_encoder_decoder = getattr(self.model_config.hf_config, + "is_encoder_decoder", False) def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) @@ -462,7 +465,7 @@ def add_request( block_size = self.cache_config.block_size seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - lora_request) + self.is_encoder_decoder, lora_request) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params,