[SLM] Baichuan Multi-GPU support (#2037)
This PR enables tensor parallelism (TP) for the Baichuan2 model.
tlopex authored Mar 28, 2024
1 parent 5ebcda1 commit a0c0f21
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions python/mlc_llm/model/baichuan/baichuan_model.py
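
For orientation before reading the diff: tensor parallelism splits every attention and MLP projection across tensor_parallel_shards devices, so each shard only holds its slice of the attention heads and of the MLP intermediate units. A minimal sketch of that arithmetic, using assumed Baichuan2-7B-style sizes that are not part of this commit:

# Illustration only: the config values below are assumptions (Baichuan2-7B-like), not from this commit.
hidden_size = 4096
num_attention_heads = 32
intermediate_size = 11008
tensor_parallel_shards = 2

head_dim = hidden_size // num_attention_heads                      # 128, the new config default
num_heads = num_attention_heads // tensor_parallel_shards          # 16 heads kept per shard
shard_intermediate = intermediate_size // tensor_parallel_shards   # 5504 MLP units per shard

w_pack_rows = 3 * num_heads * head_dim   # per-shard W_pack output rows (q, k, v segments)
o_proj_cols = num_heads * head_dim       # o_proj consumes only the local heads' outputs
gate_up_rows = 2 * shard_intermediate    # gate and up halves, each row-sharded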
@@ -13,6 +13,7 @@
 from mlc_llm import op as op_ext
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold

@@ -39,6 +40,7 @@ class BaichuanConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     prefill_chunk_size: int = 0
     tensor_parallel_shards: int = 1
     max_batch_size: int = 1
+    head_dim: int = 0
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

     def __post_init__(self):
@@ -59,6 +61,9 @@ def __post_init__(self):
                     "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                     "provided in `config.json`."
                 )
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %s (%d)",
@@ -84,11 +89,9 @@ def __post_init__(self):
 class BaichuanAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: BaichuanConfig):
         self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.max_position_embeddings = config.context_window_size
-
-        self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
+        self.head_dim = config.head_dim
+        self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
@@ -105,12 +108,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:

 class BaichuanMLP(nn.Module):
     def __init__(self, config: BaichuanConfig):
+        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
         self.gate_up_proj = nn.Linear(
             in_features=config.hidden_size,
-            out_features=2 * config.intermediate_size,
+            out_features=2 * self.intermediate_size,
             bias=False,
         )
-        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)

     def forward(self, x):
         concat_x1_x2 = self.gate_up_proj(x)
@@ -126,13 +130,41 @@ def __init__(self, config: BaichuanConfig):
         self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)
         self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_heads * hd
+            k = self.self_attn.num_heads * hd
+            v = self.self_attn.num_heads * hd
+            i = self.mlp.intermediate_size
+            _set(
+                self.self_attn.W_pack.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_mlp_gate_up", segs=[i, i], dim=0),
+            )
+            _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.mlp(self.post_attention_layernorm(hidden_states))
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

 class BaichuanModel(nn.Module):
     def __init__(self, config: BaichuanConfig):
@@ -159,7 +191,7 @@ def __init__(self, config: BaichuanConfig):
         self.num_hidden_layers = config.num_hidden_layers
         self.hidden_size = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.head_dim = config.head_dim
         self.vocab_size = config.vocab_size
         self.rope_theta = 10000
         self.tensor_parallel_shards = config.tensor_parallel_shards
@@ -187,6 +219,8 @@ def batch_forward(
         return logits

     def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
         return self.model.embed_tokens(input_ids)

     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
Expand Down Expand Up @@ -215,6 +249,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
def batch_prefill(
self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
):
if self.tensor_parallel_shards > 1:
logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
return logits, paged_kv_cache
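
Taken together, the shard hints follow the standard Megatron-style pattern: W_pack and gate_up_proj are split along their output dimension (dim=0, with explicit q/k/v and gate/up segments), o_proj and down_proj along their input dimension (dim=1), and each shard's partial result is summed back by the allreduce in _apply_residual. A standalone NumPy sketch (illustrative shapes; it does not use MLC's tp.ShardSingleDim) showing why that split-then-allreduce recombination is exact for the MLP pair:

import numpy as np

hidden, intermediate, shards = 8, 12, 2
rng = np.random.default_rng(0)

x = rng.normal(size=(1, hidden))                # one token's hidden state
up = rng.normal(size=(intermediate, hidden))    # gate_up-style weight, shape (out, in)
down = rng.normal(size=(hidden, intermediate))  # down_proj-style weight, shape (out, in)

# Unsharded reference: hidden -> intermediate -> hidden.
reference = (down @ (up @ x.T)).T

# Split "up" along dim 0 (output rows) and "down" along dim 1 (input columns),
# mirroring _shard_mlp_gate_up / _shard_mlp_down_proj, then sum the per-shard
# partial outputs, which is what the ccl_allreduce("sum") in _apply_residual does.
up_shards = np.split(up, shards, axis=0)
down_shards = np.split(down, shards, axis=1)
partials = [(d @ (u @ x.T)).T for u, d in zip(up_shards, down_shards)]
allreduced = sum(partials)

assert np.allclose(reference, allreduced)
# An elementwise activation between the two matmuls (as in the real MLP) keeps this
# exact, because each intermediate unit lives entirely on one shard.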
