
Commit a743a3b

helunwencser authored and facebook-github-bot committed
export phi-3-mini-wrapper (#4478)
Summary: Pull Request resolved: #4478 (imported using ghimport)

Test Plan: Imported from OSS
Reviewed By: iseeyuan
Differential Revision: D60483506
Pulled By: helunwencser
fbshipit-source-id: f5f019035af66af6380186e4bc57a949e6cc5480
1 parent a65700c commit a743a3b

1 file changed (+43 −20 lines)

examples/models/phi-3-mini/export_phi-3-mini.py (+43 −20)
@@ -4,6 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
+import argparse
+
 import torch
 
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -20,30 +23,43 @@
     XNNPACKQuantizer,
 )
 
-from transformers import Phi3ForCausalLM
+from transformers import AutoTokenizer, Phi3ForCausalLM
+
+from .phi_3_mini import Phi3Mini
 
 
-def main() -> None:
+def main(args) -> None:
     torch.manual_seed(0)
 
-    # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
-    model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    model_name = "microsoft/Phi-3-mini-4k-instruct"
 
-    example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
-    dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
+    with torch.no_grad():
+        model = Phi3Mini(
+            # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
+            model=Phi3ForCausalLM.from_pretrained(model_name),
+            max_batch_size=1,
+            max_seq_len=args.seq_len,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-    xnnpack_quant_config = get_symmetric_quantization_config(
-        is_per_channel=True, is_dynamic=True
-    )
-    xnnpack_quantizer = XNNPACKQuantizer()
-    xnnpack_quantizer.set_global(xnnpack_quant_config)
-
-    with torch.nn.attention.sdpa_kernel(
-        [torch.nn.attention.SDPBackend.MATH]
-    ), torch.no_grad():
-        model = capture_pre_autograd_graph(
-            model, example_inputs, dynamic_shapes=dynamic_shape
+        tokens = tokenizer.encode("Tell me a story", return_tensors="pt")
+        for input_pos in range(tokens.shape[-1]):
+            result = model.forward(
+                input_ids=tokens[:, input_pos : input_pos + 1],
+            )
+        current_token = torch.argmax(result, dim=-1).item()
+
+        example_inputs = (
+            torch.tensor([[current_token]], dtype=torch.long, requires_grad=False),
+        )
+
+        xnnpack_quant_config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=True
         )
+        xnnpack_quantizer = XNNPACKQuantizer()
+        xnnpack_quantizer.set_global(xnnpack_quant_config)
+
+        model = capture_pre_autograd_graph(model, example_inputs)
         model = prepare_pt2e(model, xnnpack_quantizer)
         model(*example_inputs)
         model = convert_pt2e(model, fold_quantize=False)
@@ -53,19 +69,26 @@ def main() -> None:
         model = torch.export._trace._export(
             model,
             example_inputs,
-            dynamic_shapes=dynamic_shape,
             strict=False,
             pre_dispatch=False,
         )
 
     edge_config = get_xnnpack_edge_compile_config()
     edge_manager = to_edge(model, compile_config=edge_config)
-    edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
+    edge_manager = edge_manager.to_backend(XnnpackPartitioner())
     et_program = edge_manager.to_executorch()
 
     with open("phi-3-mini.pte", "wb") as file:
         file.write(et_program.buffer)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-s",
+        "--seq_len",
+        type=int,
+        default=128,
+        help="Maximum number of tokens including prompt to generate",
+    )
+    main(parser.parse_args())
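
Note (not part of this commit): below is a minimal sketch of how the exported phi-3-mini.pte could be smoke-tested from Python. It assumes an ExecuTorch build whose pybindings (executorch.extension.pybindings.portable_lib) are compiled with the XNNPACK backend registered; the module path, the _load_for_executorch helper, and the forward-call convention are assumptions about the ExecuTorch pybindings at the time of this commit and may differ across versions.

import torch

# Assumption: this pybindings module exists in the local ExecuTorch build and
# links in the XNNPACK delegate that the exported program was lowered to.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

program = _load_for_executorch("phi-3-mini.pte")

# The script above traces forward() with a single (1, 1) int64 token tensor,
# so the runtime expects one token id per call; the Phi3Mini wrapper's KV
# cache keeps the running context between calls.
token = torch.tensor([[1]], dtype=torch.long)  # hypothetical token id
logits = program.forward([token])[0]
next_token = torch.argmax(logits, dim=-1).item()
print(next_token)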
