[SLM] Qwen2 Multi-GPU support (#1985)
* Update qwen2_model.py

* fix lint issue

* fix lint issue

* fix lint issue
tlopex authored Mar 25, 2024
1 parent 1c8b72e commit ab9fa81
Showing 1 changed file with 50 additions and 21 deletions.
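
This commit enables tensor parallelism for the Qwen2 model: per-shard head counts and intermediate sizes are derived from `tensor_parallel_shards`, the fused QKV projection (`c_attn`) and the MLP `gate_up_proj` are sharded along their output rows (`dim=0`, with `segs` keeping the Q/K/V and gate/up blocks separate), `o_proj` and `down_proj` are sharded along their input columns (`dim=1`), and the residual connections all-reduce the partial outputs. The snippet below is a minimal NumPy sketch of what the `segs=[q, k, v]` row sharding means; the sizes and helper names are made up for illustration, and this is not the actual `tp.ShardSingleDim` implementation.

import numpy as np

# Toy sizes for illustration only -- not the real Qwen2 configuration.
hidden_size = 1024
num_attention_heads = num_key_value_heads = 16
head_dim = 64
shards = 2

q = num_attention_heads * head_dim
k = num_key_value_heads * head_dim
v = num_key_value_heads * head_dim

# Fused QKV weight as c_attn stores it: Q, K and V stacked along dim 0.
qkv_weight = np.random.randn(q + k + v, hidden_size).astype("float32")

def shard_qkv_rows(weight, segs, num_shards):
    """Mimic ShardSingleDim(dim=0, segs=[q, k, v]) with plain slicing.

    Each segment (Q, K, V) is split across shards separately and the local
    slices are re-fused, so every GPU keeps complete heads of Q, K and V
    instead of an arbitrary slab of the fused matrix.
    """
    segments = np.split(weight, np.cumsum(segs)[:-1], axis=0)        # [Q, K, V]
    split_per_seg = [np.split(s, num_shards, axis=0) for s in segments]
    return [
        np.concatenate([split_per_seg[seg][rank] for seg in range(len(segs))], axis=0)
        for rank in range(num_shards)
    ]

local_weights = shard_qkv_rows(qkv_weight, [q, k, v], shards)
for rank, w in enumerate(local_weights):
    print(f"shard {rank}: local QKV weight {w.shape}")  # ((q + k + v) // shards, hidden_size)
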
python/mlc_llm/model/qwen2/qwen2_model.py (50 additions, 21 deletions)
@@ -13,6 +13,7 @@
 from mlc_llm import op as op_ext
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold

@@ -35,6 +36,7 @@ class QWen2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
     context_window_size: int = 0
     prefill_chunk_size: int = 0
     tensor_parallel_shards: int = 1
+    head_dim: int = 0
     dtype: str = "float32"
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

@@ -56,6 +58,9 @@ def __post_init__(self):
                 "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                 "provided in `config.json`."
             )
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %s (%d)",
@@ -80,29 +85,19 @@ def __post_init__(self):
 
 class QWen2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: QWen2Config):
-        head_dim = config.hidden_size // config.num_attention_heads
+        self.head_dim = config.head_dim
+        self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards
+        self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards
+        self.rope_theta = config.rope_theta
 
         self.c_attn = nn.Linear(
             in_features=config.hidden_size,
-            out_features=(2 * config.num_key_value_heads + config.num_attention_heads) * head_dim,
+            out_features=(2 * self.num_key_value_heads + self.num_attention_heads) * self.head_dim,
             bias=True,
         )
         self.o_proj = nn.Linear(
-            config.num_attention_heads * head_dim, config.hidden_size, bias=False
-        )
-        # KV cache for single sequence
-        self.k_cache = nn.KVCache(
-            config.context_window_size, [config.num_key_value_heads, head_dim]
+            self.num_attention_heads * self.head_dim, config.hidden_size, bias=False
         )
-        self.v_cache = nn.KVCache(
-            config.context_window_size, [config.num_attention_heads, head_dim]
-        )
-
-        self.hidden_size = config.hidden_size
-        self.head_dim = head_dim
-        self.num_attention_heads = config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.rope_theta = config.rope_theta
 
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         d, h_q, h_kv = self.head_dim, self.num_attention_heads, self.num_key_value_heads
@@ -128,8 +123,9 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
 
 class QWen2MLP(nn.Module):
     def __init__(self, config: QWen2Config):
-        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
+        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x: Tensor):
@@ -147,15 +143,46 @@ def __init__(self, config: QWen2Config):
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )
 
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_attention_heads * hd
+            k = self.self_attn.num_key_value_heads * hd
+            v = self.self_attn.num_key_value_heads * hd
+            i = self.mlp.intermediate_size
+            _set(
+                self.self_attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(
+                self.self_attn.c_attn.bias,
+                tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.gate_up_proj.weight, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)
+            )
+            _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
         out = self.self_attn(out, paged_kv_cache, layer_id)
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         return hidden_states
 
+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+
 
 class QWen2Model(nn.Module):
     def __init__(self, config: QWen2Config):
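
In this layout each GPU holds only its own attention heads and MLP columns, so `o_proj` and `down_proj` are split along their input dimension (`dim=1`) and each shard produces a partial projection; `_apply_residual` then sums those partials with `op.ccl_allreduce(out, "sum")` before adding the residual exactly once. Below is a small NumPy check of that arithmetic, using toy shapes and assumed names rather than the MLC multi-GPU runtime:

import numpy as np

# Toy shapes for illustration; not the real Qwen2 dimensions.
hidden_size, attn_out, shards = 8, 6, 2   # attn_out = num_attention_heads * head_dim

x = np.random.randn(1, attn_out)                  # full attention output (all heads)
o_proj = np.random.randn(hidden_size, attn_out)   # (out_features, in_features), like nn.Linear
residual = np.random.randn(1, hidden_size)

# Single-GPU reference: y = x @ W^T + residual
reference = x @ o_proj.T + residual

# Tensor parallel: each shard owns half of the heads, so half of x's features
# and the matching input columns of o_proj (dim=1, as in "_shard_o").
x_parts = np.split(x, shards, axis=1)
w_parts = np.split(o_proj, shards, axis=1)

# Each shard computes a *partial* projection; summing them (the allreduce in
# _apply_residual) recovers the full matmul, and the residual is added once.
partials = [x_parts[r] @ w_parts[r].T for r in range(shards)]
allreduced = np.sum(partials, axis=0)
out = allreduced + residual

assert np.allclose(out, reference, atol=1e-5)
print("row-parallel partial sums match the single-GPU projection")
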
@@ -187,7 +214,7 @@ def __init__(self, config: QWen2Config):
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
         self.tensor_parallel_shards = config.tensor_parallel_shards
-        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.head_dim = config.head_dim
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
@@ -211,6 +238,8 @@ def batch_forward(
         return logits
 
     def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
         return self.model.embed_tokens(input_ids)
 
     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
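Finally, `embed` now broadcasts `input_ids` from worker 0 when `tensor_parallel_shards > 1`, since only worker 0 receives the token ids while every worker has to run the same sequence through its shard of the weights. As a loose analogy only (MLC does this inside the compiled model with `op.ccl_broadcast_from_worker0`, not with MPI), the step corresponds to an mpi4py broadcast; the token values below are arbitrary:

# Rough analogy only: MLC's op.ccl_broadcast_from_worker0 runs inside the
# compiled model via its own communication runtime; mpi4py is used here just
# to illustrate the "worker 0 owns the tokens, everyone needs a copy" step.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Only rank 0 has the real token ids (e.g. handed over by the serving engine).
input_ids = np.array([151644, 872, 198], dtype=np.int32) if comm.rank == 0 else None

# After the broadcast every tensor-parallel worker holds identical ids and can
# run its shard of the embedding / attention / MLP layers on them.
input_ids = comm.bcast(input_ids, root=0)
print(f"rank {comm.rank}: {input_ids}")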
