# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
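
# Export script: quantizes Phi-3-mini with the XNNPACK quantizer (dynamic,
# per-channel) and lowers it to an ExecuTorch program saved as phi-3-mini.pte.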

import argparse

import torch

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini


def main(args) -> None:
    torch.manual_seed(0)

    model_name = "microsoft/Phi-3-mini-4k-instruct"

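    # Load the Hugging Face checkpoint and wrap it in the exportable Phi3Mini
    # module with a fixed batch size and maximum sequence length.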
    with torch.no_grad():
        model = Phi3Mini(
            # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
            model=Phi3ForCausalLM.from_pretrained(model_name),
            max_batch_size=1,
            max_seq_len=args.seq_len,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

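        # Prefill the prompt one token at a time so the model's cache is
        # populated, then use the next predicted token as the example input.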
        tokens = tokenizer.encode("Tell me a story", return_tensors="pt")
        for input_pos in range(tokens.shape[-1]):
            result = model.forward(
                input_ids=tokens[:, input_pos : input_pos + 1],
            )
        current_token = torch.argmax(result, dim=-1).item()

        example_inputs = (
            torch.tensor([[current_token]], dtype=torch.long, requires_grad=False),
        )

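        # Configure dynamic, per-channel symmetric quantization for XNNPACK and
        # run the PT2E flow: capture, insert observers, calibrate, convert.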
        xnnpack_quant_config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        xnnpack_quantizer = XNNPACKQuantizer()
        xnnpack_quantizer.set_global(xnnpack_quant_config)

        model = capture_pre_autograd_graph(model, example_inputs)
        model = prepare_pt2e(model, xnnpack_quantizer)
        model(*example_inputs)
        model = convert_pt2e(model, fold_quantize=False)
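        # Duplicate shared dynamic quant chains so each consumer keeps its own
        # quantize/dequantize pair, then export the quantized graph.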
        DuplicateDynamicQuantChainPass()(model)
        model = torch.export._trace._export(
            model,
            example_inputs,
            strict=False,
            pre_dispatch=False,
        )

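    # Lower to the edge dialect, delegate supported subgraphs to the XNNPACK
    # backend, and serialize the ExecuTorch program.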
    edge_config = get_xnnpack_edge_compile_config()
    edge_manager = to_edge(model, compile_config=edge_config)
    edge_manager = edge_manager.to_backend(XnnpackPartitioner())
    et_program = edge_manager.to_executorch()

    with open("phi-3-mini.pte", "wb") as file:
        file.write(et_program.buffer)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens including prompt to generate",
    )
    main(parser.parse_args())