@@ -14,7 +14,7 @@
 
 from transformers import AutoTokenizer, Phi3ForCausalLM
 
-from .static_cache import ETStaticCache
+from .phi_3_mini import Phi3Mini
 
 end_of_text_token = 32000
 
@@ -42,35 +42,22 @@ def _generate_token(args, model, prompt_tokens):
 def _generate_token_with_kv_cache(args, model, prompt_tokens):
     print("Generating tokens:", end="", flush=True)
 
-    result = model.forward(
-        input_ids=prompt_tokens,
-        use_cache=True,
-        return_dict=True,
-        past_key_values=ETStaticCache(
-            model.config,
-            prompt_tokens.shape[0],
-            args.seq_len + prompt_tokens.shape[-1],
-            device=model.device,
-            dtype=model.dtype,
-        ),
-    )
+    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])
 
-    current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
-    current_key_value = result.past_key_values
+    for input_pos in range(prompt_tokens.shape[-1]):
+        result = model.forward(
+            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
+        )
 
+    current_token = torch.argmax(result, dim=-1).item()
     print(f" {current_token}", end="", flush=True)
-
     generated_tokens = [current_token]
 
     while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
         result = model.forward(
             input_ids=torch.tensor([[current_token]], dtype=torch.long),
-            use_cache=True,
-            return_dict=True,
-            past_key_values=current_key_value,
         )
-        current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
-        current_key_value = result.past_key_values
+        current_token = torch.argmax(result, dim=-1).item()
         print(f" {current_token}", end="", flush=True)
         generated_tokens.append(current_token)
 
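The rewritten call sites assume a Phi3Mini wrapper that owns a fixed-capacity
KV cache and returns only the last-position logits, which is why the loop can
apply torch.argmax(result, dim=-1) directly instead of indexing
result.logits[:, -1, :]. The real class lives in .phi_3_mini; the sketch below
is a hypothetical reconstruction inferred from the call sites, and the
constructor arguments, cache type, and position tracking are all assumptions.

# Hypothetical sketch of the Phi3Mini wrapper, inferred from how it is
# constructed and called in the diff above; not the actual implementation.
import torch
from transformers import Phi3ForCausalLM
from transformers.cache_utils import StaticCache


class Phi3Mini(torch.nn.Module):
    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        super().__init__()
        self.model = model
        self.pos = 0  # next write position in the static cache
        # Fixed-capacity cache so the wrapper can be exported with static
        # shapes (assumption: transformers' upstream StaticCache is used).
        self.cache = StaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=model.device,
            dtype=model.dtype,
        )

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # One token per call: tell the model where to write in the cache.
        cache_position = torch.arange(
            self.pos, self.pos + input_ids.shape[-1], device=input_ids.device
        )
        result = self.model(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=self.cache,
            cache_position=cache_position,
        )
        self.pos += input_ids.shape[-1]
        # Hand back only the last position's logits so callers can argmax.
        return result.logits[:, -1, :]

Feeding the prompt through one token at a time trades prefill throughput for a
single static [1, 1] input shape, which keeps an exported graph free of dynamic
sequence lengths; sizing the cache at args.seq_len + prompt_tokens.shape[-1]
guarantees that prefill plus generation never overflows it.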
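A minimal driver for the updated loop might look as follows; the checkpoint id
microsoft/Phi-3-mini-4k-instruct and the argparse.Namespace stand-in for args
are illustrative choices, not taken from this change.

# Illustrative driver for _generate_token_with_kv_cache; the model id and
# the args object below are assumptions for the sake of a runnable example.
import argparse

import torch
from transformers import AutoTokenizer, Phi3ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
args = argparse.Namespace(seq_len=64)

with torch.no_grad():
    prompt_tokens = tokenizer("Tell me a story:", return_tensors="pt").input_ids
    _generate_token_with_kv_cache(args, model, prompt_tokens)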