Add cache_config for DeepseekV2
seungduk-yanolja authored May 27, 2024
1 parent ca72192 commit 688606f
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions vllm/model_executor/models/deepseek_v2.py
@@ -28,6 +28,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -197,6 +198,7 @@ def __init__(
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_idx=None,
     ) -> None:
@@ -276,7 +278,8 @@ def __init__(
         self.attn = Attention(self.num_local_heads,
                               256,
                               self.scaling,
-                              num_kv_heads=self.num_local_heads)
+                              num_kv_heads=self.num_local_heads,
+                              cache_config=cache_config)

     def forward(
         self,
@@ -333,6 +336,7 @@ def __init__(
         self,
         config: PretrainedConfig,
         layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -354,6 +358,7 @@ def __init__(
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
             layer_idx=layer_idx,
         )
@@ -409,6 +414,7 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -422,6 +428,7 @@ def __init__(
         self.layers = nn.ModuleList([
             DeepseekV2DecoderLayer(config,
                                    layer_idx,
+                                   cache_config=cache_config,
                                    quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
@@ -450,12 +457,13 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, quant_config)
+        self.model = DeepseekV2Model(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
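The change is pure constructor plumbing: an Optional[CacheConfig] parameter, defaulting to None, is threaded from DeepseekV2ForCausalLM through DeepseekV2Model and DeepseekV2DecoderLayer down to the Attention layer, the only place that consumes it, presumably so the attention backend can read KV-cache settings (such as block size and cache dtype) from one config object instead of per-call arguments. Below is a minimal, self-contained sketch of the same threading pattern; every class and field name in it is an illustrative stand-in, not vLLM's actual implementation.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CacheConfigSketch:
    """Illustrative stand-in for a cache config object."""
    block_size: int = 16
    cache_dtype: str = "auto"


class AttentionSketch:
    """Leaf module: the only level that actually reads the config."""

    def __init__(self, num_heads: int,
                 cache_config: Optional[CacheConfigSketch] = None) -> None:
        # Fall back to defaults when no config was threaded through,
        # mirroring the Optional[CacheConfig] = None defaults in the diff.
        cfg = cache_config or CacheConfigSketch()
        self.num_heads = num_heads
        self.block_size = cfg.block_size
        self.cache_dtype = cfg.cache_dtype


class DecoderLayerSketch:
    """Intermediate module: forwards the config without inspecting it."""

    def __init__(self, num_heads: int,
                 cache_config: Optional[CacheConfigSketch] = None) -> None:
        self.self_attn = AttentionSketch(num_heads, cache_config=cache_config)


class ModelSketch:
    """Top-level module: accepts the config once and fans it out per layer."""

    def __init__(self, num_layers: int, num_heads: int,
                 cache_config: Optional[CacheConfigSketch] = None) -> None:
        self.layers: List[DecoderLayerSketch] = [
            DecoderLayerSketch(num_heads, cache_config=cache_config)
            for _ in range(num_layers)
        ]


# One config set at the top reaches every attention layer.
model = ModelSketch(num_layers=2, num_heads=16,
                    cache_config=CacheConfigSketch(block_size=32))
assert model.layers[0].self_attn.block_size == 32

Keeping the parameter optional at every level, as the diff does, means existing call sites that never pass a cache config keep working unchanged.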
