
Commit a65700c

helunwencser authored and facebook-github-bot committed
add a wrapper for running phi-3-mini with kv cache (#4491)
Summary: Pull Request resolved: #4491

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D60554454

Pulled By: helunwencser

fbshipit-source-id: 01974a94ac1826cf63e796247a0200128293d27d
1 parent 5b37524 commit a65700c

File tree: 3 files changed (+55, -21 lines)

examples/models/phi-3-mini/__init__.py (+11)
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .phi_3_mini import Phi3Mini
+
+__all__ = [
+    Phi3Mini,
+]

examples/models/phi-3-mini/eager.py (+8, -21)
@@ -14,7 +14,7 @@

 from transformers import AutoTokenizer, Phi3ForCausalLM

-from .static_cache import ETStaticCache
+from .phi_3_mini import Phi3Mini

 end_of_text_token = 32000

@@ -42,35 +42,22 @@ def _generate_token(args, model, prompt_tokens):
 def _generate_token_with_kv_cache(args, model, prompt_tokens):
     print("Generating tokens:", end="", flush=True)

-    result = model.forward(
-        input_ids=prompt_tokens,
-        use_cache=True,
-        return_dict=True,
-        past_key_values=ETStaticCache(
-            model.config,
-            prompt_tokens.shape[0],
-            args.seq_len + prompt_tokens.shape[-1],
-            device=model.device,
-            dtype=model.dtype,
-        ),
-    )
+    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])

-    current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
-    current_key_value = result.past_key_values
+    for input_pos in range(prompt_tokens.shape[-1]):
+        result = model.forward(
+            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
+        )

+    current_token = torch.argmax(result, dim=-1).item()
     print(f" {current_token}", end="", flush=True)
-
     generated_tokens = [current_token]

     while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
         result = model.forward(
             input_ids=torch.tensor([[current_token]], dtype=torch.long),
-            use_cache=True,
-            return_dict=True,
-            past_key_values=current_key_value,
         )
-        current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
-        current_key_value = result.past_key_values
+        current_token = torch.argmax(result, dim=-1).item()
         print(f" {current_token}", end="", flush=True)
         generated_tokens.append(current_token)

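For context, the rewritten loop above is greedy decoding against the wrapper's internal KV cache: the prompt is prefilled one token at a time, then single tokens are generated until end_of_text_token or args.seq_len is reached. A minimal driver sketch follows; it is not part of this diff, and the checkpoint name is an assumption.

# Hypothetical driver for the loop above (not part of this commit).
# Assumption: the "microsoft/Phi-3-mini-4k-instruct" checkpoint name is illustrative.
import torch
from transformers import AutoTokenizer, Phi3ForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

# prompt_tokens has shape (1, T); the loop above feeds it one token at a time,
# so the static cache inside Phi3Mini is filled position by position without
# the caller having to thread past_key_values through every forward call.
prompt_tokens = tokenizer("Tell me a story", return_tensors="pt").input_ids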

examples/models/phi-3-mini/phi_3_mini.py (+36)
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch.nn
+from transformers import Phi3ForCausalLM
+
+from .static_cache import ETStaticCache
+
+
+class Phi3Mini(torch.nn.Module):
+
+    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
+        super().__init__()
+        self.model = model
+        self.cache = ETStaticCache(
+            config=model.config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_seq_len,
+            device=self.model.device,
+            dtype=self.model.dtype,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+    ) -> torch.FloatTensor:
+        return self.model.forward(
+            input_ids=input_ids,
+            use_cache=True,
+            return_dict=True,
+            past_key_values=self.cache,
+        ).logits[:, -1, :]
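As a quick illustration of the wrapper's contract (a sketch, not part of this diff; the checkpoint name and import context are assumptions): each call takes a LongTensor of token ids, runs the underlying Phi3ForCausalLM against the ETStaticCache held inside the module, and returns only the last position's logits.

# Usage sketch for Phi3Mini (illustrative; not part of this commit).
# Assumes the same package context as eager.py and an illustrative checkpoint name.
import torch
from transformers import Phi3ForCausalLM

from .phi_3_mini import Phi3Mini  # same relative import that eager.py uses

model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
wrapper = Phi3Mini(model, max_batch_size=1, max_seq_len=128)

# The cache lives inside the wrapper, so each single-token forward call both
# reads and extends it; the return value is only the last position's logits.
token = torch.tensor([[1]], dtype=torch.long)
logits = wrapper.forward(input_ids=token)         # shape (1, vocab_size)
next_token = torch.argmax(logits, dim=-1).item()  # greedy pick, as in eager.py

The single-tensor-in, single-tensor-out signature is what lets eager.py drop the use_cache/return_dict/past_key_values plumbing from its generation loop.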

0 commit comments
