cli_demo.py (forked from THUDM/CogVLM)
# -*- encoding: utf-8 -*-
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import argparse
from sat.model.mixins import CachedAutoregressiveMixin
from utils.chat import chat
from models.cogvlm_model import CogVLMModel
from utils.language import llama2_tokenizer, llama2_text_processor_inference
from utils.vision import get_image_processor

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=1, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--version", type=str, default="chat", help='version to interact with')
    parser.add_argument("--from_pretrained", type=str, default="cogvlm-chat", help='pretrained ckpt')
    parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
    parser.add_argument("--no_prompt", action='store_true', help='Sometimes there is no prompt in stage 1')
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    args = parser.parse_args()
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    parser = CogVLMModel.add_model_specific_args(parser)
    args = parser.parse_args()
    # load model for inference; weights come from the pretrained checkpoint
    model, model_args = CogVLMModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            deepspeed=None,
            local_rank=rank,
            rank=rank,
            world_size=world_size,
            model_parallel_size=world_size,
            mode='inference',
            skip_init=True,
            use_gpu_initialization=True if torch.cuda.is_available() else False,
            device='cuda',
            **vars(args)
        ), overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {})
    model = model.eval()
    from sat.mpu import get_model_parallel_world_size
    assert world_size == get_model_parallel_world_size(), "world size must equal model parallel size for cli_demo!"

    # tokenizer, image preprocessor and text processor for inference
    tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
    image_processor = get_image_processor(model_args.eva_args["image_size"][0])

    # enable cached autoregressive decoding (KV cache) for generation
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length)
    if not args.english:
        if rank == 0:
            # Chinese welcome message (same content as the English message below)
            print('欢迎使用 CogVLM-CLI ,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
    else:
        if rank == 0:
            print('Welcome to CogVLM-CLI. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
    with torch.no_grad():
        while True:
            history = None
            cache_image = None
            # only rank 0 reads from stdin; the value is then broadcast to all model-parallel ranks
            if not args.english:
                if rank == 0:
                    # prompt: "Please enter the image path or URL (press Enter for text-only conversation): "
                    image_path = [input("请输入图像路径或URL(回车进入纯文本对话): ")]
                else:
                    image_path = [None]
            else:
                if rank == 0:
                    image_path = [input("Please enter the image path or URL (press Enter for plain text conversation): ")]
                else:
                    image_path = [None]
            if world_size > 1:
                torch.distributed.broadcast_object_list(image_path, 0)
            image_path = image_path[0]
            assert image_path is not None

            if image_path == 'stop':
                break
            if args.no_prompt and len(image_path) > 0:
                query = ""
            else:
                if not args.english:
                    if rank == 0:
                        query = [input("用户:")]  # "User: "
                    else:
                        query = [None]
                else:
                    if rank == 0:
                        query = [input("User: ")]
                    else:
                        query = [None]
                if world_size > 1:
                    torch.distributed.broadcast_object_list(query, 0)
                query = query[0]
                assert query is not None
            # inner conversation loop: keep chatting about the current image until "clear" or "stop"
            while True:
                if query == "clear":
                    break
                if query == "stop":
                    sys.exit(0)
                try:
                    response, history, cache_image = chat(
                        image_path,
                        model,
                        text_processor_infer,
                        image_processor,
                        query,
                        history=history,
                        image=cache_image,
                        max_length=args.max_length,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        invalid_slices=text_processor_infer.invalid_slices,
                        no_prompt=args.no_prompt
                    )
                except Exception as e:
                    print(e)
                    break
                if rank == 0:
                    if not args.english:
                        print("模型:" + response)  # "Model: " + response
                        if tokenizer.signal_type == "grounding":
                            print("Grounding 结果已保存至 ./output.png")  # "Grounding result is saved at ./output.png"
                    else:
                        print("Model: " + response)
                        if tokenizer.signal_type == "grounding":
                            print("Grounding result is saved at ./output.png")
                # the processed image is kept in cache_image, so later turns only read a text query
                image_path = None
                if not args.english:
                    if rank == 0:
                        query = [input("用户:")]  # "User: "
                    else:
                        query = [None]
                else:
                    if rank == 0:
                        query = [input("User: ")]
                    else:
                        query = [None]
                if world_size > 1:
                    torch.distributed.broadcast_object_list(query, 0)
                query = query[0]
                assert query is not None


if __name__ == "__main__":
    main()
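
A minimal sketch of how this demo might be launched. The checkpoint name and version follow the argument defaults above; the launcher commands themselves (including the torchrun invocation, which sets RANK and WORLD_SIZE for each process) are assumptions rather than commands quoted from the repository documentation:

    # single process, model_parallel_size == 1 (assumed invocation)
    python cli_demo.py --from_pretrained cogvlm-chat --version chat --english --bf16

    # two-way model parallelism; torchrun provides RANK/WORLD_SIZE to each process (assumed invocation)
    torchrun --standalone --nnodes=1 --nproc_per_node=2 cli_demo.py --from_pretrained cogvlm-chat --version chat --english --bf16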