Demo.py
import fire
import sys
import os
import json
from pathlib import Path
from typing import List

import torch

# Make the repository root (two levels up) importable so the llama.modeling
# package and the local Tokenizer/ModelParams helpers resolve when this demo
# is run directly from its own directory.
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../..")

import llama.modeling.Loader as Loader
from Tokenizer import Tokenizer, make_context, decode_context
from ModelParams import ModelParams


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.0,
    top_p: float = 0.95,
    batch: int = 1,
    seqlen_scale_up: int = 1,
    max_gen_len: int = 512,
    friendly_gqa: bool = False, # handle GQA by repeating key and value with the key_value_cache op
    fused_qkv: bool = True, # fuse the q, k and v linear projections
    fused_kvcache: bool = True, # fuse key_value_cache and multi_head_attention
    fused_ffn_glu: bool = True, # fuse the feed-forward gated linear unit
    auto_causal: bool = True, # the attention op applies the causal mask itself, so no extra mask is passed to the model
    quantized_cache: bool = True, # 8-bit kv cache quantization
    cache_layout: int = 0, # change the kv cache layout to a hardware-friendly one
    cache_mode: int = 0, # change the kv cache indexing mode for easier memory management; only takes effect when dynamic_batching == True
    dynamic_batching: bool = True, # use dynamic batching scheduling
    context_chunking: bool = True, # enable context chunking for dynamic batching
    dump_tensor_path: str = '',
    dump_steps: List[int] = []
):
    tokenizer = Tokenizer(model_path=tokenizer_path)

    with open(Path(ckpt_dir) / "opmx_params.json", "r") as f:
        params = json.loads(f.read())
    params: ModelParams = ModelParams(**params)

    # attn_wqkv_bias_term
    generator = Loader.load(
        ckpt_dir, params,
        friendly_gqa=friendly_gqa,
        fused_qkv=fused_qkv,
        fused_kvcache=fused_kvcache,
        fused_ffn_glu=fused_ffn_glu,
        fused_alibi=False,
        auto_causal=auto_causal,
        with_rope=True,
        with_alibi=False,
        quantized_cache=quantized_cache,
        cache_layout=cache_layout,
        cache_mode=cache_mode,
        dynamic_batching=dynamic_batching,
        attn_wqkv_bias_term=True,
        attn_wo_bias_term=False,
        ffn_linear_bias_term=False,
        load_to_cpu=False,
        rotary_dim=0,
        dump_tensor_path=dump_tensor_path,
        dump_steps=dump_steps
    )
    generator.context_chunking = context_chunking if dynamic_batching else False
    test_prompt = "I believe the meaning of life is"
    raw_text, test_prompt = make_context(tokenizer, test_prompt)

    # Optionally repeat the tokenized prompt to scale up the sequence length.
    _scale_up_prompt = []
    for _ in range(seqlen_scale_up):
        _scale_up_prompt.extend(test_prompt)
    test_prompt = _scale_up_prompt

    # Duplicate the same prompt across the whole batch.
    prompt_tokens = [test_prompt for _ in range(batch)]
    print(f"prepared {len(prompt_tokens)} prompts")

    results = generator.generate(
        prompt_tokens[:batch], tokenizer.get_eos_id(), tokenizer.get_pad_id(),
        max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, top_k=0
    )

    for result in results:
        if torch.is_tensor(result):
            result = result.cpu().numpy().tolist()
        result = decode_context(result, tokenizer=tokenizer, raw_text_len=len(raw_text), context_length=len(test_prompt))
        print(result)
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)
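
# Example invocation (a sketch, not taken from the repository docs; the checkpoint
# and tokenizer paths below are placeholders to replace with your own):
#
#   python Demo.py --ckpt_dir /path/to/opmx_checkpoint \
#       --tokenizer_path /path/to/tokenizer.model \
#       --batch 2 --max_gen_len 128 --dynamic_batching True
#
# fire.Fire(main) exposes every keyword argument of main() as a command-line flag,
# so the fusion and kv-cache options above can be toggled the same way.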