
Commit

Merge pull request #1 from arcee-ai/kv-cache-explore
MacBook Pro tested!
shamanez authored Oct 16, 2024
2 parents e55e9a3 + 1a3eaf1 commit c9cc018
Showing 5 changed files with 61 additions and 11 deletions.
1 change: 1 addition & 0 deletions download_weights.py
@@ -81,6 +81,7 @@ def main(model_id: str, out_dir: Path):
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
token = t_path.read_text().strip()
hf_model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.bfloat16, offload_folder="/tmp/offload", token=token)

with torch.no_grad():
state_dict = hf_model.state_dict()
for hf_name, param in state_dict.items():
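
For context on the token argument shown above: gated checkpoints (such as the Llama weights this repo targets) require an authenticated download, and offload_folder lets from_pretrained spill weights to disk on memory-constrained machines. A minimal sketch of the same call pattern, assuming the token is kept in a plain-text file; the t_path location and the model id below are hypothetical stand-ins, not values from this PR:

from pathlib import Path

import torch
from transformers import AutoModelForCausalLM

# Hypothetical token location; download_weights.py defines its own t_path.
t_path = Path.home() / ".hf_token"
token = t_path.read_text().strip()

# Same pattern as the diff: bfloat16 weights, disk offload for machines that
# cannot hold the whole model in RAM, and the token for gated repositories.
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",  # example model id, not from this PR
    torch_dtype=torch.bfloat16,
    offload_folder="/tmp/offload",
    token=token,
)
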
15 changes: 13 additions & 2 deletions entropix/torch_main.py
@@ -1,4 +1,10 @@
from typing import NamedTuple, Optional, Tuple
import time
import sys
from pathlib import Path

# Add the parent directory of 'entropix' to the Python path
sys.path.append(str(Path(__file__).parent.parent))

import torch
import torch.nn.functional as F
@@ -101,6 +107,7 @@ def main():


def generate(xfmr_weights, model_params, tokens):
start_time = time.time()
gen_tokens = None
cur_pos = 0
tokens = torch.tensor([tokens], dtype=torch.long).to(device)
@@ -122,8 +129,12 @@ def generate(xfmr_weights, model_params, tokens):
print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
if torch.isin(next_token, stop).any():
break

print(prompt)
print("\n\n")
print("============")
end_time = time.time()
elapsed_time = end_time - start_time
print(f"\n\nGeneration completed in {elapsed_time:.2f} seconds.")

generate(xfmr_weights, model_params, raw_tokens1)

if __name__ == '__main__':
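
Two notes on the torch_main.py changes: the sys.path.append at the top makes the entropix package importable when the script is run directly rather than with python -m, and the new timing code brackets the whole generate() call with time.time(). A minimal sketch of the same instrumentation factored into a reusable context manager, using time.perf_counter() (a monotonic, high-resolution clock generally preferred for measuring intervals); the helper name is ours, not part of the repo:

import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    # Report the elapsed wall-clock time for the enclosed block.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"\n\n{label} completed in {elapsed:.2f} seconds.")

# Hypothetical usage, mirroring the call in the diff:
# with timed("Generation"):
#     generate(xfmr_weights, model_params, raw_tokens1)
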
33 changes: 25 additions & 8 deletions entropix/torch_model.py
@@ -39,25 +39,42 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
def attention(x: torch.Tensor, layer_weights: LayerWeights, model_params, cur_pos: int, layer_idx: int, freqs_cis: torch.Tensor, kvcache: KVCache, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, KVCache, torch.Tensor]:
bsz, _, _ = x.shape
n_rep = model_params.n_local_heads // model_params.n_local_kv_heads
xq = F.linear(x, layer_weights.wq).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
xk = F.linear(x, layer_weights.wk).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
xv = F.linear(x, layer_weights.wv).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)

# Ensure x is on the correct device
x = x.to(device)

xq = F.linear(x, layer_weights.wq.to(device)).reshape(bsz, -1, model_params.n_local_heads, model_params.head_dim)
xk = F.linear(x, layer_weights.wk.to(device)).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)
xv = F.linear(x, layer_weights.wv.to(device)).reshape(bsz, -1, model_params.n_local_kv_heads, model_params.head_dim)

# Ensure freqs_cis is on the correct device
freqs_cis = freqs_cis.to(device)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
keys, values, kvcache = kvcache.update(xk, xv, layer_idx, cur_pos, n_rep)

xq = torch.permute(xq, (0, 2, 1, 3)) # (bs, n_heads, seqlen, head_dim)
keys = torch.permute(keys, (0, 2, 3, 1)) # (bs, n_heads, head_dim, cache_len + seqlen)
values = torch.permute(values, (0, 2, 1, 3)) # (bs, n_heads, cache_len + seqlen, head_dim)

# Ensure all tensors are on the same device before matmul
xq, keys, values = xq.to(device), keys.to(device), values.to(device)

scores = torch.matmul(xq, keys)
pre_scores = scores / math.sqrt(model_params.head_dim)
scores = pre_scores.to(torch.float32) # Always do attention softmax at float32
if cur_pos == 0:

if cur_pos == 0 and attn_mask is not None:
attn_mask = attn_mask.to(device)
scores = scores + attn_mask
mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE)

mask = torch.where(scores != 0.0, scores, torch.tensor(DEFAULT_MASK_VALUE, device=device))
padded_logits = torch.where((mask >= DEFAULT_MASK_VALUE * 0.5), scores, torch.tensor(DEFAULT_MASK_VALUE, device=device))
scores = F.softmax(padded_logits, dim=-1).to(torch.float32)
output = torch.matmul(scores, values)
output = output.transpose(1, 2).reshape(xq.shape[0], xq.shape[2], -1)
out = F.linear(output, layer_weights.wo)
out = F.linear(output, layer_weights.wo.to(device))

return out, kvcache, pre_scores

def feed_forward(x: torch.Tensor, layer_weights: LayerWeights) -> torch.Tensor:
@@ -77,4 +94,4 @@ def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: torch.Ten
h = h + h_attn
h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
logits = F.linear(rms_norm(h, xfmr_weights.norm), xfmr_weights.output)
return logits, kvcache, scores, attn_stats
return logits, kvcache, scores, attn_stats
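
The repeated .to(device) calls in attention() assume a module-level device constant defined elsewhere in the torch code (it is not part of this diff). Given the "MacBook Pro tested" note, moving every operand onto that device is presumably what makes the kernel run on Apple's MPS backend as well as CUDA; wrapping DEFAULT_MASK_VALUE in torch.tensor(..., device=device) likewise keeps both arguments of torch.where on the same device as scores, avoiding device-mismatch errors. A minimal sketch of how such a device constant is commonly chosen; this is an assumption about the definition, not a line from this PR:

import torch

# Prefer CUDA, fall back to Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
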
2 changes: 1 addition & 1 deletion entropix/torch_weights.py
@@ -72,7 +72,7 @@ def load_weights(ckpt_dir: Path = Path('weights/1B-Instruct'), n_layers: int = 1
ffn_norm=w[f'layers.{i}.ffn_norm.weight'],
attention_norm=w[f'layers.{i}.attention_norm.weight'],
))

xfmr_weights = XfmrWeights(
tok_embeddings=w['tok_embeddings.weight'],
norm=w['norm.weight'],
21 changes: 21 additions & 0 deletions requirements.txt
@@ -0,0 +1,21 @@
# Main dependencies
flax==0.9.0
tiktoken==0.4.0
pydantic==2.9.2
blobfile==3.0.0
ml-dtypes==0.5.0
rich==13.8.1
torch==2.4.1 ; platform_system != 'Darwin'
chex==0.1.87
tyro==0.8.11

# Dev dependencies
pytest==8.3.2
ruff==0.6.2
transformers==4.45.1
huggingface-hub[cli]==0.25.1

# Test dependencies
fairscale==0.4.13

torch
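
The new requirements.txt lists torch twice on purpose: torch==2.4.1 carries a PEP 508 environment marker that skips the pin on macOS (platform_system == 'Darwin'), while the bare torch entry at the bottom has no marker, so macOS installs the current wheel (which ships MPS support) and every other platform resolves to the 2.4.1 pin. A small illustrative snippet showing how such a marker evaluates, using the standard packaging library (not itself a dependency of this repo):

from packaging.markers import Marker

marker = Marker("platform_system != 'Darwin'")

# Evaluate against the current interpreter's environment...
print(marker.evaluate())

# ...or against an override, e.g. pretending to run on macOS (prints False).
print(marker.evaluate({"platform_system": "Darwin"}))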
