[example] Added (hacky) Grok1 support #171

Status: Open. Wants to merge 2 commits into base: main (showing changes from all commits).
9 changes: 9 additions & 0 deletions mixtral-moe/README.md
@@ -1,3 +1,12 @@
# Grok-1 Support
```
export MODEL_REPO=hpcai-tech/grok-1
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8

TOKENIZERS_PARALLELISM=false ENABLE_INTRA_NODE_COMM=1 time torchrun --standalone --nproc_per_node=8 generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --compile --compile_prefill
```
# Mixtral 8x7B
[Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) is a high-quality sparse mixture of experts (MoE) model that matches or beats GPT3.5 on most benchmarks. This repo is a simple and efficient PyTorch native implementation of Mixtral 8x7B.

10 changes: 6 additions & 4 deletions mixtral-moe/generate.py
@@ -131,7 +131,7 @@ def generate(
def encode_tokens(tokenizer, string, bos=True, device='cuda'):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
tokens = [tokenizer.bos_token_id] + tokens
return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision, use_tp):
@@ -174,7 +174,7 @@ def main(
"""
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
tokenizer_path = checkpoint_path.parent / "tokenizer.json"
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
@@ -196,7 +196,9 @@ def main(
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
# tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

Expand Down Expand Up @@ -235,7 +237,7 @@ def callback(x):
if done_generating:
return
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
if x.item() == tokenizer.eos_id():
if x.item() == tokenizer.eos_token_id:
done_generating = True
if len(buffer) == 4 or done_generating:
print(''.join(buffer), end='', flush=True)
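The generate.py changes above swap the SentencePiece tokenizer for a Hugging Face one, so the bos_id()/eos_id() method calls become bos_token_id/eos_token_id attribute lookups. A minimal standalone sketch of that usage (not part of the diff; it assumes the hpcai-tech/grok-1 repository serves a compatible tokenizer through AutoTokenizer with trust_remote_code):

```
# Illustrative only: the same tokenizer swap outside the diff.
import torch
from transformers import AutoTokenizer

def encode_tokens(tokenizer, string, bos=True, device="cpu"):
    tokens = tokenizer.encode(string)
    if bos:
        # HF tokenizers expose IDs as attributes (bos_token_id / eos_token_id),
        # not as methods like SentencePiece's bos_id() / eos_id().
        tokens = [tokenizer.bos_token_id] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

tok = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
ids = encode_tokens(tok, "Hello from Grok-1", bos=True)
print(ids.tolist(), "eos:", tok.eos_token_id)
```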
30 changes: 23 additions & 7 deletions mixtral-moe/model.py
@@ -50,9 +50,14 @@ def from_name(cls, name: str):
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])

attn_output_multiplier = 0.08838834764831845
embedding_multiplier_scale = 78.38367176906169
output_multiplier_scale = 0.5773502691896257
max_attn_val = 30.0

transformer_configs = {
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
"grok-1": dict(vocab_size=131072, block_size=8192, n_layer=64, n_head=48, n_local_heads=8, dim=6144, intermediate_size=32768, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
}
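Three of the new constants appear to be simple functions of the grok-1 entry just added to transformer_configs: with dim=6144 and n_head=48 the head dimension is 128, attn_output_multiplier is 1/sqrt(128), embedding_multiplier_scale is sqrt(6144), and output_multiplier_scale is 1/sqrt(3), all up to float rounding; max_attn_val is the tanh cap used in the attention rewrite further down. A quick numerical check (illustrative only, not part of the diff):

```
# Illustrative check: the grok-1 scaling constants versus the config entry above.
import math

dim, n_head = 6144, 48
head_dim = dim // n_head            # 128
print(1 / math.sqrt(head_dim))      # ~0.0884, matches attn_output_multiplier
print(math.sqrt(dim))               # ~78.38,  matches embedding_multiplier_scale
print(1 / math.sqrt(3))             # ~0.5774, matches output_multiplier_scale
```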

class KVCache(nn.Module):
@@ -106,11 +111,13 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)
x *= embedding_multiplier_scale

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
logits = self.output(x)
logits *= output_multiplier_scale
return logits

@classmethod
@@ -123,12 +130,14 @@ def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.block_sparse_moe = MOEFeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
self.pre_moe_norm = RMSNorm(config.dim, config.norm_eps)
self.post_moe_norm = RMSNorm(config.dim, config.norm_eps)
self.post_attn_norm = RMSNorm(config.dim, config.norm_eps)
self.pre_attn_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.block_sparse_moe(self.ffn_norm(h))
h = x + self.post_attn_norm(self.attention(self.pre_attn_norm(x), freqs_cis, mask, input_pos))
out = h + self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(h)))
return out
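The block above normalizes both before and after the attention and MoE sub-layers instead of using the single attention_norm/ffn_norm pair. A minimal standalone sketch of this sandwich-norm residual pattern (not the PR code; it assumes a PyTorch release that ships nn.RMSNorm, and the sub-layers are stand-ins):

```
# Minimal sketch of pre/post "sandwich" normalization around two sub-layers.
import torch
import torch.nn as nn

class SandwichBlock(nn.Module):
    def __init__(self, dim: int, attn: nn.Module, moe: nn.Module) -> None:
        super().__init__()
        self.pre_attn_norm = nn.RMSNorm(dim)
        self.post_attn_norm = nn.RMSNorm(dim)
        self.pre_moe_norm = nn.RMSNorm(dim)
        self.post_moe_norm = nn.RMSNorm(dim)
        self.attn, self.moe = attn, moe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize before and after each sub-layer, then add the residual.
        h = x + self.post_attn_norm(self.attn(self.pre_attn_norm(x)))
        return h + self.post_moe_norm(self.moe(self.pre_moe_norm(h)))

# Usage with stand-in sub-layers:
block = SandwichBlock(16, nn.Linear(16, 16), nn.Linear(16, 16))
print(block(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])
```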


@@ -160,7 +169,8 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
qkv = self.wqkv(x)
q, k, v = qkv.split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -176,7 +186,13 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
attn_weights = torch.matmul(q, k.transpose(2, 3)).to(torch.float32)
attn_weights = attn_weights * attn_output_multiplier
attn_weights = max_attn_val * F.tanh(attn_weights / max_attn_val)
attn_weights += torch.where(mask, 0, -float("inf"))
attn_weights = F.softmax(attn_weights, dim=-1).to(q.dtype)
y = torch.matmul(attn_weights, v)
# y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

Expand All @@ -195,7 +211,7 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
w1_weights = self.w1[expert_indices] # [T, A, D, D]
w3_weights = self.w3[expert_indices] # [T, A, D, D]
w2_weights = self.w2[expert_indices] # [T, A, D, D]
x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
x1 = F.gelu(torch.einsum('ti,taoi -> tao', x, w1_weights))
x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
return expert_outs
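The attention hunk above replaces F.scaled_dot_product_attention with explicit matmuls so that Grok-1's logit multiplier and tanh soft-cap can be applied before the softmax. A self-contained sketch of that computation (constants copied from the diff; shapes and the toy usage are illustrative):

```
# Standalone sketch of the capped-logit attention used in the diff above.
import torch
import torch.nn.functional as F

ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845
MAX_ATTN_VAL = 30.0

def capped_attention(q, k, v, mask):
    """q, k, v: [B, H, T, D]; mask: bool, True where attention is allowed."""
    logits = torch.matmul(q, k.transpose(-2, -1)).to(torch.float32)
    logits = logits * ATTN_OUTPUT_MULTIPLIER
    # Soft-cap the logits to (-30, 30) with tanh instead of leaving them unbounded.
    logits = MAX_ATTN_VAL * torch.tanh(logits / MAX_ATTN_VAL)
    logits = logits + torch.where(mask, 0.0, float("-inf"))
    weights = F.softmax(logits, dim=-1).to(q.dtype)
    return torch.matmul(weights, v)

# Tiny usage example with a causal mask.
B, H, T, D = 1, 2, 4, 8
q = k = v = torch.randn(B, H, T, D)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))[None, None]
print(capped_attention(q, k, v, mask).shape)  # torch.Size([1, 2, 4, 8])
```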
64 changes: 43 additions & 21 deletions mixtral-moe/scripts/convert_hf_checkpoint.py
@@ -32,42 +32,52 @@ def convert_hf_checkpoint(
print(f"Model config {config.__dict__}")

weight_map = {
"tok_embeddings.weight": "tok_embeddings.weight",
"layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
"layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
"layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
"layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
"layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
"layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
"layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
"layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
"layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
"layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
"norm.weight": "norm.weight",
"output.weight": "output.weight",
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.attn.o_proj.weight": "layers.{}.attention.wo.weight",
# "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
# "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
# "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.moe_block.experts.{}.linear.weight": "layers.{}.block_sparse_moe.cond_ffn.w1.{}",
"model.layers.{}.moe_block.experts.{}.linear_1.weight": "layers.{}.block_sparse_moe.cond_ffn.w2.{}",
"model.layers.{}.moe_block.experts.{}.linear_v.weight": "layers.{}.block_sparse_moe.cond_ffn.w3.{}",
"model.layers.{}.moe_block.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
"model.layers.{}.pre_attn_norm.scale": "layers.{}.pre_attn_norm.weight",
"model.layers.{}.post_attn_norm.scale": "layers.{}.post_attn_norm.weight",
"model.layers.{}.pre_moe_norm.scale": "layers.{}.pre_moe_norm.weight",
"model.layers.{}.post_moe_norm.scale": "layers.{}.post_moe_norm.weight",
"model.norm.scale": "norm.weight",
"lm_head.weight": "output.weight",
}

pt_files = glob.glob(str(checkpoint_dir / "*.pt"))
pt_files = glob.glob(str(checkpoint_dir / "*.bin"))

merged_result = {}
for file in sorted(pt_files):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
for key, value in list(merged_result.items()):
if "layers" in key:
abstract_key = re.sub(r'.(\d+).', '.{}.', key)
layer_num = re.search(r'\d+', key).group(0)
abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)
nums = re.findall(r'\d+', key)
if abstract_key not in weight_map:
continue
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
new_key = new_key.format(*nums)
else:
if key not in weight_map:
continue
new_key = weight_map[key]

final_result[new_key] = value
del merged_result[key]

for key in tuple(final_result.keys()):
print(key)
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
@@ -77,9 +87,21 @@ def convert_hf_checkpoint(
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
elif "w1" in key or "w3" in key:
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
if not key.endswith('0'):
continue
full_keys = [key[:-1] + str(i) for i in range(8)]
results = [final_result[k] for k in full_keys]
final_result[key[:-2]] = torch.stack(results, dim=0)
for k in full_keys:
del final_result[k]
elif "w2" in key:
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
if not key.endswith('0'):
continue
full_keys = [key[:-1] + str(i) for i in range(8)]
results = [final_result[k] for k in full_keys]
final_result[key[:-2]] = torch.stack(results, dim=0)
for k in full_keys:
del final_result[k]
elif "gate" in key:
final_result[key] = final_result[key].contiguous()

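The conversion script now rewrites each Hugging Face key by abstracting the layer and expert indices, then stacks the eight per-expert matrices (keys ending in .0 through .7) into a single tensor with a leading num_experts dimension. A toy sketch of those two steps (not the PR code; tensor sizes are tiny stand-ins for Grok-1's real dim=6144 and intermediate_size=32768):

```
# Toy sketch of the key remapping and per-expert stacking performed above.
import re
import torch

num_experts, dim, intermediate = 8, 4, 6   # stand-in sizes
hf_state = {
    f"model.layers.0.moe_block.experts.{i}.linear.weight": torch.randn(intermediate, dim)
    for i in range(num_experts)
}
weight_map = {
    "model.layers.{}.moe_block.experts.{}.linear.weight":
        "layers.{}.block_sparse_moe.cond_ffn.w1.{}",
}

# Step 1: abstract the numeric indices, look up the new name, then re-insert them.
converted = {}
for key, value in hf_state.items():
    abstract_key = re.sub(r"\.(\d+)\.", ".{}.", key)
    nums = re.findall(r"\d+", key)
    converted[weight_map[abstract_key].format(*nums)] = value

# Step 2: stack the eight per-expert matrices into one [num_experts, out, in] tensor.
w1 = torch.stack(
    [converted.pop(f"layers.0.block_sparse_moe.cond_ffn.w1.{i}") for i in range(num_experts)],
    dim=0,
)
print(w1.shape)  # torch.Size([8, 6, 4])
```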