from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold

@@ -39,6 +40,7 @@ class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attribute
    prefill_chunk_size: int = 0
    tensor_parallel_shards: int = 1
    max_batch_size: int = 1
+    head_dim: int = 0
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
@@ -59,6 +61,9 @@ def __post_init__(self):
                "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                "provided in `config.json`."
            )
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
        if self.prefill_chunk_size == 0:
            logger.info(
                "%s defaults to %s (%d)",
@@ -84,11 +89,9 @@ def __post_init__(self):
class BaichuanAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
    def __init__(self, config: BaichuanConfig):
        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.max_position_embeddings = config.context_window_size
-
-        self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
+        self.head_dim = config.head_dim
+        self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
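To make the new sizing concrete: each shard keeps `num_attention_heads // tensor_parallel_shards` heads, so the packed QKV projection `W_pack` and the output projection `o_proj` shrink by the same factor. A small sketch under assumed values (32 heads, `head_dim` 128, 2 shards); concatenating the per-shard pieces recovers the full projections.

```python
# Assumed illustrative sizes; not read from a real config.
num_attention_heads, head_dim, tensor_parallel_shards = 32, 128, 2
hidden_size = num_attention_heads * head_dim                  # 4096

num_heads = num_attention_heads // tensor_parallel_shards     # 16 heads live on each shard
w_pack_rows = 3 * num_heads * head_dim                        # per-shard W_pack output rows: 6144
o_proj_cols = num_heads * head_dim                            # per-shard o_proj input columns: 2048

# Across all shards the original shapes are recovered.
assert w_pack_rows * tensor_parallel_shards == 3 * hidden_size
assert o_proj_cols * tensor_parallel_shards == hidden_size
```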
@@ -105,12 +108,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:

class BaichuanMLP(nn.Module):
    def __init__(self, config: BaichuanConfig):
+        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
        self.gate_up_proj = nn.Linear(
            in_features=config.hidden_size,
-            out_features=2 * config.intermediate_size,
+            out_features=2 * self.intermediate_size,
            bias=False,
        )
-        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        concat_x1_x2 = self.gate_up_proj(x)
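The MLP is sharded the same way: each shard owns a `1/tensor_parallel_shards` slice of the intermediate dimension, which is why `gate_up_proj` and `down_proj` are now built from `self.intermediate_size` rather than the config value. Another quick sizing sketch under assumed numbers:

```python
# Assumed illustrative sizes; not read from a real config.
intermediate_size, hidden_size, tensor_parallel_shards = 11008, 4096, 2

local_intermediate = intermediate_size // tensor_parallel_shards   # 5504 per shard
gate_up_rows = 2 * local_intermediate                              # fused gate + up rows on this shard
down_proj_cols = local_intermediate                                # down_proj consumes only the local slice
assert local_intermediate * tensor_parallel_shards == intermediate_size
```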
@@ -126,13 +130,41 @@ def __init__(self, config: BaichuanConfig):
        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_heads * hd
+            k = self.self_attn.num_heads * hd
+            v = self.self_attn.num_heads * hd
+            i = self.mlp.intermediate_size
+            _set(
+                self.self_attn.W_pack.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_mlp_gate_up", segs=[i, i], dim=0),
+            )
+            _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
        out = self.mlp(self.post_attention_layernorm(hidden_states))
-        hidden_states = out + hidden_states
+        hidden_states = self._apply_residual(out, residual=hidden_states)
        return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

class BaichuanModel(nn.Module):
    def __init__(self, config: BaichuanConfig):
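A conceptual sketch of what these shard hints request, in NumPy (the actual partitioning is done by MLC's loader from the `shard_strategy` attributes, not by code like this; the toy shapes and the 2-shard count are assumptions). `ShardSingleDim(dim=0, segs=[q, k, v])` slices each of the packed Q/K/V blocks row-wise so every shard keeps its own heads, while `dim=1` on `o_proj`/`down_proj` splits input columns, which is why `_apply_residual` sums the partial outputs with an all-reduce before adding the residual.

```python
import numpy as np

# Assumed toy sizes: 2 shards, 4 heads of dim 8 -> hidden size 32.
shards, num_heads, head_dim = 2, 4, 8
hidden = num_heads * head_dim
q = k = v = num_heads * head_dim

rng = np.random.default_rng(0)
w_pack = rng.standard_normal((q + k + v, hidden))  # packed [Q; K; V] weight, rows = outputs

def shard_dim0_with_segs(weight, segs, shard_id, num_shards):
    """Take this shard's slice of each segment along dim 0, then re-concatenate."""
    pieces, start = [], 0
    for seg in segs:
        block = weight[start : start + seg]
        step = seg // num_shards
        pieces.append(block[shard_id * step : (shard_id + 1) * step])
        start += seg
    return np.concatenate(pieces, axis=0)

# Each shard ends up with (q + k + v) // shards rows: the Q, K and V of its own heads.
local = shard_dim0_with_segs(w_pack, [q, k, v], shard_id=0, num_shards=shards)
assert local.shape == ((q + k + v) // shards, hidden)

# o_proj / down_proj are split along dim 1 (input features), so each shard produces a
# partial result; summing them is what ccl_allreduce(..., "sum") does in _apply_residual.
o_proj = rng.standard_normal((hidden, num_heads * head_dim))
attn_out = rng.standard_normal(num_heads * head_dim)
half = attn_out.size // shards
partials = [
    o_proj[:, i * half : (i + 1) * half] @ attn_out[i * half : (i + 1) * half]
    for i in range(shards)
]
assert np.allclose(sum(partials), o_proj @ attn_out)
```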
@@ -159,7 +191,7 @@ def __init__(self, config: BaichuanConfig):
        self.num_hidden_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.rope_theta = 10000
        self.tensor_parallel_shards = config.tensor_parallel_shards
@@ -187,6 +219,8 @@ def batch_forward(
        return logits

    def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
        return self.model.embed_tokens(input_ids)

    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
@@ -215,6 +249,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
    def batch_prefill(
        self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
    ):
+        if self.tensor_parallel_shards > 1:
+            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
        return logits, paged_kv_cache

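On the input side, only worker 0 is handed the real `input_ids` and `logit_positions`, so they are broadcast to every shard before the embedding lookup and the logit gather; otherwise the shards would run their local matmuls on inconsistent inputs. A trivial stand-in for the pattern (`op.ccl_broadcast_from_worker0` is the real call; the list-of-workers model below is purely illustrative):

```python
# Purely illustrative: model each tensor-parallel worker as a slot in a list.
def broadcast_from_worker0(per_worker):
    """Every worker ends up with worker 0's value, mirroring ccl_broadcast_from_worker0."""
    return [per_worker[0] for _ in per_worker]

workers = [[1, 2, 3], None, None, None]  # only worker 0 received real token ids
assert broadcast_from_worker0(workers) == [[1, 2, 3]] * 4
```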