Commit a0c0f21

[SLM] Baichuan Multi-GPU support (#2037)
This PR enables tensor parallelism (TP) for the Baichuan2 model.
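For a concrete sense of what the sharding below does to the layer shapes, here is a quick back-of-the-envelope check in Python. It assumes Baichuan2-7B-like dimensions (hidden size 4096, 32 attention heads, intermediate size 11008) and 2 shards; the numbers are illustrative and are not taken from this commit.

    # Illustrative only (dimensions assumed, not read from this commit):
    # how tensor_parallel_shards divides the attention heads and MLP width.
    hidden_size = 4096
    num_attention_heads = 32
    intermediate_size = 11008
    tensor_parallel_shards = 2  # e.g. two GPUs

    head_dim = hidden_size // num_attention_heads              # 128, as in __post_init__
    num_heads = num_attention_heads // tensor_parallel_shards  # 16 heads per shard
    inter = intermediate_size // tensor_parallel_shards        # 5504 per shard

    # Per-shard out_features implied by the modified layers:
    w_pack_out = 3 * num_heads * head_dim  # 6144 (fused QKV), vs. 12288 unsharded
    gate_up_out = 2 * inter                # 11008, vs. 22016 unsharded
    print(head_dim, num_heads, inter, w_pack_out, gate_up_out)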
1 parent 5ebcda1 · commit a0c0f21

1 file changed (+46, -10)

python/mlc_llm/model/baichuan/baichuan_model.py

Lines changed: 46 additions & 10 deletions
@@ -13,6 +13,7 @@
 from mlc_llm import op as op_ext
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold

@@ -39,6 +40,7 @@ class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attribute
     prefill_chunk_size: int = 0
     tensor_parallel_shards: int = 1
     max_batch_size: int = 1
+    head_dim: int = 0
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

     def __post_init__(self):
@@ -59,6 +61,9 @@ def __post_init__(self):
                 "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                 "provided in `config.json`."
             )
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %s (%d)",
@@ -84,11 +89,9 @@ def __post_init__(self):
 class BaichuanAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: BaichuanConfig):
         self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.max_position_embeddings = config.context_window_size
-
-        self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
+        self.head_dim = config.head_dim
+        self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
@@ -105,12 +108,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:

 class BaichuanMLP(nn.Module):
     def __init__(self, config: BaichuanConfig):
+        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
         self.gate_up_proj = nn.Linear(
             in_features=config.hidden_size,
-            out_features=2 * config.intermediate_size,
+            out_features=2 * self.intermediate_size,
             bias=False,
         )
-        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)

     def forward(self, x):
         concat_x1_x2 = self.gate_up_proj(x)
@@ -126,13 +130,41 @@ def __init__(self, config: BaichuanConfig):
         self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)
         self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_heads * hd
+            k = self.self_attn.num_heads * hd
+            v = self.self_attn.num_heads * hd
+            i = self.mlp.intermediate_size
+            _set(
+                self.self_attn.W_pack.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_mlp_gate_up", segs=[i, i], dim=0),
+            )
+            _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.mlp(self.post_attention_layernorm(hidden_states))
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

 class BaichuanModel(nn.Module):
     def __init__(self, config: BaichuanConfig):
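As the shard hints above read, W_pack and gate_up_proj are sharded along dim=0 of their weights (output features), with segs splitting the fused Q/K/V and gate/up segments separately so every shard keeps a slice of each segment, while o_proj and down_proj are sharded along dim=1 (input features). After a dim=1-sharded matmul each shard holds only a partial sum, which is why the residual add goes through op.ccl_allreduce(out, "sum") when tensor_parallel_shards > 1. A minimal NumPy sketch of that arithmetic (an illustration of the pattern, not MLC's runtime):

    # Illustration only: a dim=0-sharded matmul followed by a dim=1-sharded
    # matmul needs one sum-allreduce to match the unsharded result.
    import numpy as np

    rng = np.random.default_rng(0)
    shards, hidden, inter = 2, 8, 16

    x = rng.standard_normal((1, hidden))
    w_up = rng.standard_normal((inter, hidden))    # stands in for gate_up_proj.weight (split on dim=0)
    w_down = rng.standard_normal((hidden, inter))  # stands in for down_proj.weight (split on dim=1)

    # Unsharded reference (ReLU stands in for the real activation).
    reference = np.maximum(x @ w_up.T, 0) @ w_down.T

    partials = []
    for s in range(shards):
        seg = slice(s * inter // shards, (s + 1) * inter // shards)
        h_s = np.maximum(x @ w_up[seg].T, 0)       # complete values for this shard's features
        partials.append(h_s @ w_down[:, seg].T)    # partial sum over the intermediate axis
    assert np.allclose(reference, sum(partials))   # the allreduce("sum") step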
@@ -159,7 +191,7 @@ def __init__(self, config: BaichuanConfig):
         self.num_hidden_layers = config.num_hidden_layers
         self.hidden_size = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.head_dim = config.head_dim
         self.vocab_size = config.vocab_size
         self.rope_theta = 10000
         self.tensor_parallel_shards = config.tensor_parallel_shards
@@ -187,6 +219,8 @@ def batch_forward(
         return logits

     def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
         return self.model.embed_tokens(input_ids)

     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
@@ -215,6 +249,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
     def batch_prefill(
         self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
     ):
+        if self.tensor_parallel_shards > 1:
+            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
         logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
         return logits, paged_kv_cache
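The embed and batch_prefill hunks broadcast worker 0's inputs (token ids and logit positions) so that every tensor-parallel worker operates on the same data. As an analogy only, using torch.distributed rather than MLC's ccl ops, the same pattern looks like the sketch below; the script name and launch command are hypothetical.

    # Analogy only: a rank-0 broadcast with torch.distributed (not MLC's ccl ops).
    # Launch with e.g.: torchrun --nproc_per_node=2 broadcast_demo.py
    import torch
    import torch.distributed as dist

    def main():
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()
        # Only rank 0 starts with the real token ids; other ranks hold placeholders.
        ids = torch.tensor([101, 2023, 2003]) if rank == 0 else torch.zeros(3, dtype=torch.long)
        dist.broadcast(ids, src=0)  # afterwards every rank holds rank 0's ids
        print(f"rank {rank}: {ids.tolist()}")
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()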
