From 169b530607c0102fdb02ce1fd3323fd6085477b0 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 14 Oct 2024 20:24:25 -0400
Subject: [PATCH] [Bugfix] Clean up some cruft in mamba.py (#9343)

---
 docs/source/models/supported_models.rst |   2 +-
 vllm/model_executor/models/mamba.py     | 113 +++---------------------
 2 files changed, 11 insertions(+), 104 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 926ffab6d9287..102842b0a188d 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -155,7 +155,7 @@ Text Generation
   * - :code:`MambaForCausalLM`
     - Mamba
     - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-    - ✅︎
+    -
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 1112a2181135a..b86b687a9c361 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -1,6 +1,5 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -10,7 +9,6 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -39,13 +37,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
         return contextualized_states
 
 
-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class MambaDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -252,7 +212,6 @@ def __init__(self,
         self.config = config
         self.mixer = MambaMixer(config, layer_idx)
 
-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ def forward(
 
         hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
@@ -319,7 +274,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
@@ -346,26 +300,6 @@
 
 
 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
@@ -416,8 +350,8 @@ def forward(self,
 
         mamba_cache_tensors = self.mamba_cache.current_run_tensors(
            input_ids, attn_metadata, **kwargs)
-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_tensors[0],
                                       mamba_cache_tensors[1])
 
         return hidden_states
@@ -457,43 +391,16 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if "A_log" in name:
                 name = name.replace("A_log", "A")
 
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)