
Commit 5884f55

update

1 parent 61b889e · commit 5884f55

15 files changed: +2415 −45 lines

CNAME (+1)

@@ -0,0 +1 @@
+www.luokai.tech

_config.yml (+45 −45)

@@ -14,62 +14,62 @@ timezone:
 # jekyll-seo-tag settings › https://github.com/jekyll/jekyll-seo-tag/blob/master/docs/usage.md
 # ↓ --------------------------

-title: Chirpy # the main title
+title: K's blog # the main title

-tagline: A text-focused Jekyll theme # it will display as the sub-title
+tagline: 个人笔记记录 # it will display as the sub-title

 description: >- # used by seo meta and the atom feed
-  A minimal, responsive and feature-rich Jekyll theme for technical writing.
+  个人笔记记录

 # Fill in the protocol & hostname for your site.
 # e.g. 'https://username.github.io', note that it does not end with a '/'.
 url: ""

-github:
-  username: github_username # change to your github username
-
-twitter:
-  username: twitter_username # change to your twitter username
-
-social:
-  # Change to your full name.
-  # It will be displayed as the default author of the posts and the copyright owner in the Footer
-  name: your_full_name
-  email: [email protected] # change to your email address
-  links:
-    # The first element serves as the copyright owner's link
-    - https://twitter.com/username # change to your twitter homepage
-    - https://github.com/username # change to your github homepage
-    # Uncomment below to add more social links
-    # - https://www.facebook.com/username
-    # - https://www.linkedin.com/in/username
-
-# Site Verification Settings
-webmaster_verifications:
-  google: # fill in your Google verification code
-  bing: # fill in your Bing verification code
-  alexa: # fill in your Alexa verification code
-  yandex: # fill in your Yandex verification code
-  baidu: # fill in your Baidu verification code
-  facebook: # fill in your Facebook verification code
+# github:
+#   username: github_username # change to your github username
+
+# twitter:
+#   username: twitter_username # change to your twitter username
+
+# social:
+#   # Change to your full name.
+#   # It will be displayed as the default author of the posts and the copyright owner in the Footer
+#   name: your_full_name
+#   email: [email protected] # change to your email address
+#   links:
+#     # The first element serves as the copyright owner's link
+#     - https://twitter.com/username # change to your twitter homepage
+#     - https://github.com/username # change to your github homepage
+#     # Uncomment below to add more social links
+#     # - https://www.facebook.com/username
+#     # - https://www.linkedin.com/in/username
+
+# # Site Verification Settings
+# webmaster_verifications:
+#   google: # fill in your Google verification code
+#   bing: # fill in your Bing verification code
+#   alexa: # fill in your Alexa verification code
+#   yandex: # fill in your Yandex verification code
+#   baidu: # fill in your Baidu verification code
+#   facebook: # fill in your Facebook verification code

 # ↑ --------------------------
 # The end of `jekyll-seo-tag` settings

-# Web Analytics Settings
-analytics:
-  google:
-    id: # fill in your Google Analytics ID
-  goatcounter:
-    id: # fill in your GoatCounter ID
-  umami:
-    id: # fill in your Umami ID
-    domain: # fill in your Umami domain
-  matomo:
-    id: # fill in your Matomo ID
-    domain: # fill in your Matomo domain
-  cloudflare:
-    id: # fill in your Cloudflare Web Analytics token
+# # Web Analytics Settings
+# analytics:
+#   google:
+#     id: # fill in your Google Analytics ID
+#   goatcounter:
+#     id: # fill in your GoatCounter ID
+#   umami:
+#     id: # fill in your Umami ID
+#     domain: # fill in your Umami domain
+#   matomo:
+#     id: # fill in your Matomo ID
+#     domain: # fill in your Matomo domain
+#   cloudflare:
+#     id: # fill in your Cloudflare Web Analytics token

 # Pageviews settings
 pageviews:

@@ -86,7 +86,7 @@ pageviews:
 # light - Use the light color scheme
 # dark - Use the dark color scheme
 #
-theme_mode: # [light | dark]
+theme_mode: light # [light | dark]

 # The CDN endpoint for media resources.
 # Notice that once it is assigned, the CDN url
New file (+193)

@@ -0,0 +1,193 @@
---
title: Simple concurrency modifications to the open-source LLM ChatGLM
categories: [Deep Learning]
description: Modifying the open-source ChatGLM release to improve its concurrency
keywords:
- chatglm
- llm
- large language model
- concurrency
date: 2023-11-01
draft: false
---

### Summary
- The open-source chatglm3-6b only ships an API that generates a whole reply in one continuous pass. Deployed with a single worker, if several people ask questions at the same time, the next answer cannot start until the previous one has fully finished, which feels like a very long wait on the user side. So, following the chatglm3 source code, I wrote a simple concurrent API. It needs a bit more GPU memory, but when several people ask at once their answers are generated simultaneously. Each individual answer is produced more slowly; you can think of it as the concurrent users splitting the token generation rate.
- This is a stopgap; it will be replaced later by a proper high-performance inference framework.

### Overall approach
Change the generate function so that instead of looping until a whole reply is finished, each call performs exactly one inference step. A FastAPI request service carries the context and calls it repeatedly; with several workers on the request service, concurrent questions can be served without waiting for a whole reply to finish before the next one starts.
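
Each question thus becomes a short sequence of HTTP round-trips keyed by a `cache_id`, where every response carries only the newly generated text. A hypothetical exchange (all values illustrative) could look like this:

```python
# One question, polled step by step (illustrative values only).
request = {"cache_id": "d3b07384-...", "query": "你好", "history": []}
# 1st response: {"flag": "sending", "delta_text": "你好"}
# 2nd response: {"flag": "sending", "delta_text": ",我是"}
# ...
# last response: {"flag": "end", "delta_text": "。"}  # client stops polling
```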

#### Implementation
The inference service, api.py
```python
import datetime
import gc
from typing import Any, List

import torch
from pydantic import BaseModel
from transformers import (AutoModel, AutoTokenizer, LogitsProcessor,
                          LogitsProcessorList, StoppingCriteriaList)


class Message(BaseModel):
    cache_id: str
    query: str
    history: List[List[str] | Any] = []
    model_name: str = "chatglm3-6b"
    temperature: float = 0.95
    top_p: float = 0.7
    max_length: int = 8192
    do_sample: bool = True


class CacheMessage(BaseModel):
    flag: str
    delta_text: str


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class ChatModel:
    def __init__(self, model_path: str = "/data/git_source/huggingface/THUDM/chatglm3-6b"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.device = "cuda"
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
        self.model.eval()
        # self.redis = redis.Redis(host='localhost', port=6379, db=0, password="redispass")
        self.logits_processor = LogitsProcessorList()
        self.logits_processor.append(InvalidScoreLogitsProcessor())
        self.stopping_criteria = StoppingCriteriaList()
        # in-memory cache of the replies that are still being generated
        self.cache = {}

    # adapted from the chatglm source: each call runs a single inference step
    @torch.inference_mode()
    def generate(self, message: Message) -> CacheMessage:
        gen_kwargs = {"max_length": message.max_length,
                      "do_sample": message.do_sample,
                      "top_p": message.top_p,
                      "temperature": message.temperature,
                      "logits_processor": self.logits_processor}
        kwargs = gen_kwargs
        # is this a sentence we are already generating?
        if message.cache_id in self.cache:
            msg = self.cache[message.cache_id]
            if msg["flag"] == "end":
                del self.cache[message.cache_id]
                return {"flag": msg["flag"],
                        "delta_text": msg["delta_text"]}
            input_ids = msg["input_ids"]
            model_kwargs = self.cache[message.cache_id]["model_kwargs"]
        # a new sentence arrives under a fresh unique id
        else:
            inputs = self.tokenizer.build_chat_input(message.query, history=message.history, role="user")
            input_ids = inputs["input_ids"].to(self.device)
            model_kwargs = self.model.generation_config.update(**kwargs)
            model_kwargs["use_cache"] = self.model.generation_config.use_cache
            msg = {
                "flag": "sending",
                "input_ids": input_ids,
                "model_kwargs": model_kwargs,
                "input_ids_raw_len": input_ids.shape[1],
                "previous_text": "",
                "delta_text": "",
                "create": datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
            self.cache[message.cache_id] = msg
        # the single inference step
        _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
        _, eos_token_id = self.model.generation_config.bos_token_id, self.model.generation_config.eos_token_id
        eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        logits_processor = self.model._get_logits_processor(
            generation_config=self.model.generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=None,
            logits_processor=self.logits_processor,
        )

        stopping_criteria = self.model._get_stopping_criteria(
            generation_config=self.model.generation_config, stopping_criteria=self.stopping_criteria
        )
        logits_warper = self.model._get_logits_warper(self.model.generation_config)
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self.model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
        if self.model.generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self.model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder
        )
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        response = self.tokenizer.decode(input_ids.tolist()[0][self.cache[message.cache_id]["input_ids_raw_len"]:-1])
        self.cache[message.cache_id]["input_ids"] = input_ids
        if response:
            delta_text = response[len(self.cache[message.cache_id]["previous_text"]):]
            self.cache[message.cache_id]["delta_text"] = delta_text
            # "�" marks an incomplete multi-byte character; hold the text back
            # until the next step completes it
            if response[-1] != "�":
                self.cache[message.cache_id]["flag"] = "sending"
                self.cache[message.cache_id]["previous_text"] = response
            else:
                self.cache[message.cache_id]["flag"] = "hang"
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, None):
            self.cache[message.cache_id]["flag"] = "end"

        gc.collect()
        torch.cuda.empty_cache()
        return {"flag": self.cache[message.cache_id]["flag"],
                "delta_text": self.cache[message.cache_id]["delta_text"]}
```
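
The diff shows the model class but not the FastAPI wiring around it. Below is a minimal sketch of how `ChatModel.generate` could be exposed at the `/llm/generate` path the client code calls; the `app` object and endpoint function name are assumptions, not from the original api.py:

```python
# Sketch of the FastAPI wiring for api.py (names assumed, not from the diff).
from fastapi import FastAPI

app = FastAPI()
chat_model = ChatModel()  # one model instance per worker


@app.post("/llm/generate")
def llm_generate(message: Message) -> CacheMessage:
    # one inference step per request; callers poll until flag == "end"
    return chat_model.generate(message)
```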
The service that calls the inference service, chatglm.py
```python
import json
import uuid
from typing import List

import requests

# `AnswerResult`, `self.chat_config`, `self.config` and `self.loginfo` come
# from the surrounding project; only the stream_chat method is shown here.
async def stream_chat(self, prompt: str, history: List[List[str]] = [], **kw):
    for k in self.chat_config:
        if k not in kw:
            kw[k] = self.chat_config[k]
    msg_history = []
    if len(history) > 0:
        for q, a in history:
            msg_history.append({"role": "user", "content": q})
            msg_history.append({"role": "assistant", "content": a})
    msg_history.append({"role": "user", "content": prompt})
    msg = {
        "cache_id": str(uuid.uuid4()),
        "query": prompt,
        "history": msg_history,
        **kw}
    headers = {'Content-Type': 'application/json'}
    history += [[]]
    # keep calling the inference service until the sentence ends; calling
    # again after "end" would start a brand-new generation
    while True:
        payload = json.dumps(msg)
        response = requests.post(f"http://{self.config['server_url']}/llm/generate", headers=headers, data=payload)
        if response.status_code == 200:
            resp = response.json()
            self.loginfo(f"raw response: delta_text {resp['delta_text']}")
            if resp["flag"] in ("sending", "end"):
                r = resp["delta_text"]
                history[-1] = [prompt, r]
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": r}
                yield answer_result
            if resp["flag"] == "end":
                break
        else:
            break
```
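
For reference, a hypothetical consumer of `stream_chat` (here `llm` stands in for whichever object hosts the method):

```python
# Hypothetical caller: print the reply as the delta chunks stream in.
async def demo(llm):
    async for result in llm.stream_chat("介绍一下你自己"):
        print(result.llm_output["answer"], end="", flush=True)
```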
Two services have to be started: api.py runs with a single worker (with enough GPU memory you could run several), and chatglm.py runs with as many workers as the desired concurrency; a launch sketch follows below.
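
As a launch sketch, assuming both files expose a FastAPI `app` object and are served with uvicorn (module names, ports and worker counts are all assumptions):

```python
# Launch sketch; module/app names and ports are assumed.
import uvicorn

# inference service: one worker keeps a single copy of the model on the GPU
uvicorn.run("api:app", host="0.0.0.0", port=8000, workers=1)

# request service: scale workers to the desired concurrency, e.g.
# uvicorn.run("chatglm:app", host="0.0.0.0", port=8001, workers=8)
```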

In practice, when several questions are submitted at once, the later ones no longer wait for the previous answer to complete before getting a streamed reply; they start receiving an answer immediately.
