From 2299e5d84d456db95a087a099b8f40bdf6d2b73c Mon Sep 17 00:00:00 2001
From: Shamane Siri
Date: Wed, 16 Oct 2024 16:06:08 +1300
Subject: [PATCH 1/2] Fix the device issue related to the matmul

---
 download_weights.py       |  1 +
 entropix/torch_main.py    |  8 +++++++-
 entropix/torch_model.py   | 33 +++++++++++++++++++++++++--------
 entropix/torch_weights.py |  2 +-
 requirements.txt          | 21 +++++++++++++++++++++
 5 files changed, 55 insertions(+), 10 deletions(-)
 create mode 100644 requirements.txt

diff --git a/download_weights.py b/download_weights.py
index b8d19d3..b3d74a7 100644
--- a/download_weights.py
+++ b/download_weights.py
@@ -81,6 +81,7 @@ def main(model_id: str, out_dir: Path):
   with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
     token = t_path.read_text().strip()
     hf_model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.bfloat16, offload_folder="/tmp/offload", token=token)
+  with torch.no_grad():
     state_dict = hf_model.state_dict()
     for hf_name, param in state_dict.items():
diff --git a/entropix/torch_main.py b/entropix/torch_main.py
index 64cf1ee..a0bfa59 100644
--- a/entropix/torch_main.py
+++ b/entropix/torch_main.py
@@ -1,5 +1,11 @@
 from typing import NamedTuple, Optional, Tuple
 
+import sys
+from pathlib import Path
+
+# Add the parent directory of 'entropix' to the Python path
+sys.path.append(str(Path(__file__).parent.parent))
+
 import torch
 import torch.nn.functional as F
 
@@ -112,6 +118,7 @@ def generate(xfmr_weights, model_params, tokens):
     next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32)
     gen_tokens = next_token
     print(tokenizer.decode([next_token.item()]), end='', flush=True)
+    exit()
     cur_pos = seqlen
     stop = torch.tensor([128001, 128008, 128009], device=device, dtype=torch.int32)
     while cur_pos < 8192:
@@ -123,7 +130,6 @@ def generate(xfmr_weights, model_params, tokens):
       if torch.isin(next_token, stop).any():
         break
 
-  print(prompt)
   generate(xfmr_weights, model_params, raw_tokens1)
 
 if __name__ == '__main__':
diff --git a/entropix/torch_model.py b/entropix/torch_model.py
index 0ebb3e9..fb64bda 100644
--- a/entropix/torch_model.py
+++ b/entropix/torch_model.py
@@ -39,25 +39,42 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
 def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: torch.Tensor, kvcache: KVCache, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, KVCache, torch.Tensor]:
   bsz, _, _ = x.shape
   n_rep = model_params.n_local_heads // model_params.n_local_kv_heads
-  xq = F.linear(x, layer_weights.wq).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
-  xk = F.linear(x, layer_weights.wk).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
-  xv = F.linear(x, layer_weights.wv).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
+
+  # Ensure x is on the correct device
+  x = x.to(device)
+
+  xq = F.linear(x, layer_weights.wq.to(device)).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
+  xk = F.linear(x, layer_weights.wk.to(device)).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
+  xv = F.linear(x, layer_weights.wv.to(device)).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
+
+  # Ensure freqs_cis is on the correct device
+  freqs_cis = freqs_cis.to(device)
+
   xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
   keys, values, kvcache = kvcache.update(xk, xv, layer_idx, cur_pos, n_rep)
+
   xq = torch.permute(xq, (0, 2, 1, 3)) # (bs, n_heads, seqlen, head_dim)
   keys = torch.permute(keys, (0, 2, 3, 1)) # (bs, n_heads, head_dim, cache_len + seqlen)
   values = torch.permute(values, (0, 2, 1, 3)) # (bs, n_heads, cache_len + seqlen, head_dim)
+
+  # Ensure all tensors are on the same device before matmul
+  xq, keys, values = xq.to(device), keys.to(device), values.to(device)
+
   scores = torch.matmul(xq, keys)
   pre_scores = scores / math.sqrt(model_params.head_dim)
   scores = pre_scores.to(torch.float32) # Always do attention softmax at float32
-  if cur_pos == 0:
+
+  if cur_pos == 0 and attn_mask is not None:
+    attn_mask = attn_mask.to(device)
     scores = scores + attn_mask
-  mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
-  padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)
+
+  mask = torch.where(scores != 0.0, scores, torch.tensor(DEFAULT_MASK_VALUE, device=device))
+  padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, torch.tensor(DEFAULT_MASK_VALUE, device=device))
   scores = F.softmax(padded_logits, dim=-1).to(torch.float32)
   output = torch.matmul(scores, values)
   output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1)
-  out = F.linear(output, layer_weights.wo)
+  out = F.linear(output, layer_weights.wo.to(device))
+
   return out, kvcache, pre_scores
 
 def feed_forward(x: torch.Tensor, layer_weights: LayerWeights) -> torch.Tensor:
@@ -77,4 +94,4 @@ def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: torch.Ten
     h = h + h_attn
     h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
   logits = F.linear(rms_norm(h, xfmr_weights.norm), xfmr_weights.output)
-  return logits, kvcache, scores, attn_stats
\ No newline at end of file
+  return logits, kvcache, scores, attn_stats
diff --git a/entropix/torch_weights.py b/entropix/torch_weights.py
index ccbf187..14b3417 100644
--- a/entropix/torch_weights.py
+++ b/entropix/torch_weights.py
@@ -72,7 +72,7 @@ def load_weights(ckpt_dir: Path = Path('weights/1B-Instruct'), n_layers: int = 1
         ffn_norm=w[f'layers.{i}.ffn_norm.weight'],
         attention_norm=w[f'layers.{i}.attention_norm.weight'],
       ))
-
+
     xfmr_weights = XfmrWeights(
       tok_embeddings=w['tok_embeddings.weight'],
       norm=w['norm.weight'],
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0082cda
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,21 @@
+# Main dependencies
+flax==0.9.0
+tiktoken==0.4.0
+pydantic==2.9.2
+blobfile==3.0.0
+ml-dtypes==0.5.0
+rich==13.8.1
+torch==2.4.1 ; platform_system != 'Darwin'
+chex==0.1.87
+tyro==0.8.11
+
+# Dev dependencies
+pytest==8.3.2
+ruff==0.6.2
+transformers==4.45.1
+huggingface-hub[cli]==0.25.1
+
+# Test dependencies
+fairscale==0.4.13
+
+torch
\ No newline at end of file

From 1a3eaf1087e729f4ec320ba2a35e750b20753744 Mon Sep 17 00:00:00 2001
From: Shamane Siri
Date: Wed, 16 Oct 2024 17:34:38 +1300
Subject: [PATCH 2/2] Measure and print generation execution time
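
This commit records a start timestamp at the top of generate(), prints the
elapsed wall-clock seconds once the decode loop finishes, and removes the
temporary exit() call left over from the previous commit. The timing pattern
is roughly the following minimal sketch (a standalone illustration, not the
patched file; run_generation is a hypothetical stand-in for the real
token-by-token decode loop):

    import time

    def run_generation() -> list:
        # Hypothetical stand-in for the token-by-token decode loop.
        return [128000, 9906, 1917]

    start_time = time.time()
    tokens = run_generation()
    elapsed_time = time.time() - start_time
    print(f"\n\nGeneration completed in {elapsed_time:.2f} seconds.")

time.time() reports wall-clock time, which is what matters for perceived
generation latency; time.perf_counter() would give higher resolution if finer
measurements were ever needed.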
---
 entropix/torch_main.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/entropix/torch_main.py b/entropix/torch_main.py
index a0bfa59..712b926 100644
--- a/entropix/torch_main.py
+++ b/entropix/torch_main.py
@@ -1,5 +1,5 @@
 from typing import NamedTuple, Optional, Tuple
-
+import time
 import sys
 from pathlib import Path
 
@@ -107,6 +107,7 @@ def main():
 
 
   def generate(xfmr_weights, model_params, tokens):
+    start_time = time.time()
     gen_tokens = None
     cur_pos = 0
     tokens = torch.tensor([tokens], dtype=torch.long).to(device)
@@ -118,7 +119,6 @@ def generate(xfmr_weights, model_params, tokens):
     next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32)
     gen_tokens = next_token
     print(tokenizer.decode([next_token.item()]), end='', flush=True)
-    exit()
     cur_pos = seqlen
     stop = torch.tensor([128001, 128008, 128009], device=device, dtype=torch.int32)
     while cur_pos < 8192:
@@ -129,7 +129,12 @@ def generate(xfmr_weights, model_params, tokens):
       print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
       if torch.isin(next_token, stop).any():
         break
-
+    print("\n\n")
+    print("============")
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"\n\nGeneration completed in {elapsed_time:.2f} seconds.")
+
   generate(xfmr_weights, model_params, raw_tokens1)
 
 if __name__ == '__main__':
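
Note on PATCH 1/2: the matmul failure it addresses is the usual symptom of
mixing CPU and GPU tensors, which PyTorch reports as a RuntimeError along the
lines of "Expected all tensors to be on the same device". The fix moves the
activations, the rotary frequencies, the attention mask, and the projection
weights onto the same `device` before any matmul touches them. A minimal
sketch of that pattern (a standalone illustration with made-up shapes and
names, not the patched attention function):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def scaled_scores(xq: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        # Placing both operands on the same device avoids the
        # "Expected all tensors to be on the same device" RuntimeError.
        xq, keys = xq.to(device), keys.to(device)
        return torch.matmul(xq, keys) / xq.shape[-1] ** 0.5

    # Illustrative shapes: (bsz, n_heads, seqlen, head_dim) @ (bsz, n_heads, head_dim, seqlen)
    scores = scaled_scores(torch.randn(1, 4, 8, 64), torch.randn(1, 4, 64, 8))

Tensor.to(device) returns the tensor unchanged when it already lives on that
device, so the per-call conversions are cheap once everything is resident on
`device`; if the weights start out on the CPU, however, each call pays for a
host-to-device copy, and moving them once at load time is the cheaper
long-term fix.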