forked from THUDM/ChatGLM3
Showing 15 changed files with 1,115 additions and 1 deletion.
Submodule composite_demo deleted from 496992
@@ -0,0 +1,2 @@
[theme]
font = "monospace"
@@ -0,0 +1,85 @@
# ChatGLM3 Web Demo

*(demo screenshot)*

## Installation

We recommend managing the environment with [Conda](https://docs.conda.io/en/latest/).

Run the following commands to create a new conda environment and install the required dependencies:

```bash
conda create -n chatglm3-demo python=3.10
conda activate chatglm3-demo
pip install -r requirements.txt
```

Please note that this project requires Python 3.10 or higher.

In addition, using the Code Interpreter also requires installing a Jupyter kernel:

```bash
ipython kernel install --name chatglm3-demo --user
```

## Running

Run the following command to load the model locally and start the demo:

```bash
streamlit run main.py
```

The demo's address will then appear on the command line; click it to open the demo. The first visit downloads and loads the model, which may take some time.
If you have already downloaded the model locally, you can make the demo load it from disk with `export MODEL_PATH=/path/to/model`. If you need a custom Jupyter kernel, specify it with `export IPYKERNEL=<kernel_name>`.
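For example, a typical local setup might look like the following sketch; the model directory is a placeholder, so substitute your own path:

```bash
# Load weights from a local directory instead of downloading from Hugging Face
export MODEL_PATH=/path/to/chatglm3-6b   # placeholder path
# Optional: use a custom Jupyter kernel for the Code Interpreter
export IPYKERNEL=chatglm3-demo
streamlit run main.py
```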
## Usage

The ChatGLM3 demo has three modes:

- Chat: dialogue mode, in which you can converse with the model.
- Tool: tool mode, in which the model can, in addition to chatting, perform other operations by calling tools.
- Code Interpreter: code interpreter mode, in which the model can execute code in a Jupyter environment and read back the results to complete complex tasks.

### Chat Mode

In chat mode, you can adjust the model's behavior directly from the sidebar by modifying parameters such as top_p, temperature, and the system prompt. For example:

*(chat mode screenshot)*

### Tool Mode

You can extend the model's capabilities by registering new tools in `tool_registry.py`. Registering a tool only requires decorating a function with `@register_tool`. For the tool declaration, the function name is the tool's name and the function docstring is the tool's description; for the tool's parameters, use `Annotated[typ: type, description: str, required: bool]` to annotate each parameter's type, description, and whether it is required.
For example, the `get_weather` tool is registered as follows:

```python
@register_tool
def get_weather(
    city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the weather for `city_name` in the following week
    """
    ...
```
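`tool_registry.py` itself is not part of this diff, but a decorator along the following lines could implement the convention just described. This is only a sketch: `_TOOL_DESCRIPTIONS` and the exact schema fields are assumptions for illustration.

```python
from typing import Annotated, get_type_hints

_TOOL_DESCRIPTIONS = {}  # hypothetical registry: tool name -> schema


def register_tool(func):
    # The function name becomes the tool name; the docstring, its description.
    tool_name = func.__name__
    tool_description = (func.__doc__ or '').strip()
    params = []
    # include_extras=True preserves the Annotated[...] metadata.
    for name, hint in get_type_hints(func, include_extras=True).items():
        if name == 'return':
            continue
        typ = hint.__origin__                  # e.g. str for Annotated[str, ...]
        description, required = hint.__metadata__
        params.append({
            'name': name,
            'type': typ.__name__,
            'description': description,
            'required': required,
        })
    _TOOL_DESCRIPTIONS[tool_name] = {
        'name': tool_name,
        'description': tool_description,
        'params': params,
    }
    return func
```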
*(tool mode screenshot)*

You can also enter manual mode via `Manual mode` on the page. In this mode you specify the tool list directly in YAML, but you must feed the tools' outputs back to the model manually.
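As a rough illustration, a manual-mode tool list might look like the YAML below. The exact field names the demo expects are an assumption here, mirroring the registry sketch above:

```yaml
# Hypothetical manual-mode tool list; the field names are assumptions.
- name: get_weather
  description: Get the weather for `city_name` in the following week
  params:
    - name: city_name
      description: The name of the city to be queried
      type: str
      required: true
```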
### Code Interpreter Mode

Because it has access to a code execution environment, the model in this mode can carry out more complex tasks, such as drawing charts or performing symbolic computation. The model executes multiple code blocks in succession, guided by its own assessment of the task's progress, until the task is complete. You therefore only need to state the task you want the model to perform.

For example, we can ask ChatGLM3 to draw a heart:

*(code interpreter screenshot)*

### Additional Tips

- While the model is generating text, you can interrupt it with the `Stop` button at the top right of the page.
- Refresh the page to clear the conversation history.

# Enjoy!
@@ -0,0 +1,137 @@
from __future__ import annotations

from collections.abc import Iterable
import os
from typing import Any, Protocol

from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

from conversation import Conversation

TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')

@st.cache_resource
def get_client() -> Client:
    client = HFClient(MODEL_PATH)
    return client

class Client(Protocol):
    def generate_stream(self,
                        system: str | None,
                        tools: list[dict] | None,
                        history: list[Conversation],
                        **parameters: Any
                        ) -> Iterable[TextGenerationStreamResponse]:
        ...

# Module-level copy of the model's stream_chat method; `self` is the model
# instance, passed explicitly by the caller.
def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] = None, role: str = "user",
                past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                logits_processor=None, return_past_key_values=False, **kwargs):

    from transformers.generation.logits_process import LogitsProcessor
    from transformers.generation.utils import LogitsProcessorList

    # Guards against NaN/inf logits by forcing all probability mass onto token 5.
    class InvalidScoreLogitsProcessor(LogitsProcessor):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
                scores[..., 5] = 5e4
            return scores

    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    if past_key_values is None:
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
    else:
        inputs = tokenizer.build_chat_input(query, role=role)
    inputs = inputs.to(self.device)
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[0]
        if self.transformer.pre_seq_len is not None:
            past_length -= self.transformer.pre_seq_len
        inputs.position_ids += past_length
        attention_mask = inputs.attention_mask
        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
        inputs['attention_mask'] = attention_mask
    history.append({"role": role, "content": query})
    for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                        eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                        **gen_kwargs):
        if return_past_key_values:
            outputs, past_key_values = outputs
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        # Skip chunks that end in an incomplete multi-byte character.
        if response and response[-1] != "�":
            new_history = history
            if return_past_key_values:
                yield response, new_history, past_key_values
            else:
                yield response, new_history

class HFClient(Client):
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Prefer CUDA, then Apple Silicon (MPS), then CPU.
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(
            'cuda' if torch.cuda.is_available() else
            'mps' if torch.backends.mps.is_available() else
            'cpu'
        )
        self.model = self.model.eval()

    def generate_stream(self,
                        system: str | None,
                        tools: list[dict] | None,
                        history: list[Conversation],
                        **parameters: Any
                        ) -> Iterable[TextGenerationStreamResponse]:
        chat_history = [{
            'role': 'system',
            'content': system if not tools else TOOL_PROMPT,
        }]

        if tools:
            chat_history[0]['tools'] = tools

        for conversation in history[:-1]:
            chat_history.append({
                'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
                'content': conversation.content,
            })

        query = history[-1].content
        role = str(history[-1].role).removeprefix('<|').removesuffix('|>')

        text = ''

        for new_text, _ in stream_chat(self.model,
                                       self.tokenizer,
                                       query,
                                       chat_history,
                                       role,
                                       **parameters,
                                       ):
            word = new_text.removeprefix(text)
            word_stripped = word.strip()
            text = new_text
            yield TextGenerationStreamResponse(
                generated_text=text,
                token=Token(
                    id=0,
                    logprob=0,
                    text=word,
                    special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
                )
            )
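As a quick sanity check, a minimal driver for `HFClient` could look like the sketch below. It is illustrative only, not part of this commit, and assumes `client.py` and `conversation.py` are importable:

```python
from client import get_client
from conversation import Conversation, Role

client = get_client()
history = [Conversation(role=Role.USER, content='Hello!')]

# Stream tokens to stdout; extra parameters are forwarded to stream_chat.
for response in client.generate_stream(
        system='You are a helpful assistant.',
        tools=None,
        history=history,
        top_p=0.8,
        temperature=0.8,
):
    print(response.token.text, end='', flush=True)
```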
@@ -0,0 +1,119 @@
from dataclasses import dataclass
from enum import auto, Enum
import json

from PIL.Image import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator

TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n'

class Role(Enum):
    SYSTEM = auto()
    USER = auto()
    ASSISTANT = auto()
    TOOL = auto()
    INTERPRETER = auto()
    OBSERVATION = auto()

    def __str__(self):
        match self:
            case Role.SYSTEM:
                return "<|system|>"
            case Role.USER:
                return "<|user|>"
            case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
                return "<|assistant|>"
            case Role.OBSERVATION:
                return "<|observation|>"

    # Get the message block for the given role
    def get_message(self):
        # Compare by value here, because the enum object in the session state
        # is not the same as the enum cases here, due to streamlit's rerunning
        # behavior.
        match self.value:
            case Role.SYSTEM.value:
                return
            case Role.USER.value:
                return st.chat_message(name="user", avatar="user")
            case Role.ASSISTANT.value:
                return st.chat_message(name="assistant", avatar="assistant")
            case Role.TOOL.value:
                return st.chat_message(name="tool", avatar="assistant")
            case Role.INTERPRETER.value:
                return st.chat_message(name="interpreter", avatar="assistant")
            case Role.OBSERVATION.value:
                return st.chat_message(name="observation", avatar="user")
            case _:
                st.error(f'Unexpected role: {self}')

@dataclass
class Conversation:
    role: Role
    content: str
    tool: str | None = None
    image: Image | None = None

    def __str__(self) -> str:
        # Debug trace of the conversation being serialized.
        print(self.role, self.content, self.tool)
        match self.role:
            case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
                return f'{self.role}\n{self.content}'
            case Role.TOOL:
                return f'{self.role}{self.tool}\n{self.content}'
            case Role.INTERPRETER:
                return f'{self.role}interpreter\n{self.content}'

    # Human readable format
    def get_text(self) -> str:
        text = postprocess_text(self.content)
        match self.role.value:
            case Role.TOOL.value:
                text = f'Calling tool `{self.tool}`:\n{text}'
            case Role.INTERPRETER.value:
                text = f'{text}'
            case Role.OBSERVATION.value:
                text = f'Observation:\n```\n{text}\n```'
        return text

    # Display as a markdown block
    def show(self, placeholder: DeltaGenerator | None = None) -> None:
        if placeholder:
            message = placeholder
        else:
            message = self.role.get_message()
        if self.image:
            message.image(self.image)
        else:
            text = self.get_text()
            message.markdown(text)

def preprocess_text(
        system: str | None,
        tools: list[dict] | None,
        history: list[Conversation],
) -> str:
    if tools:
        tools = json.dumps(tools, indent=4, ensure_ascii=False)

    prompt = f"{Role.SYSTEM}\n"
    prompt += system if not tools else TOOL_PROMPT
    if tools:
        tools = json.loads(tools)
        prompt += json.dumps(tools, ensure_ascii=False)
    for conversation in history:
        prompt += f'{conversation}'
    prompt += f'{Role.ASSISTANT}\n'
    return prompt

def postprocess_text(text: str) -> str:
    # Convert LaTeX delimiters to the $ / $$ forms streamlit's markdown renders,
    # and strip the special role tokens from the displayed text.
    text = text.replace("\\(", "$")
    text = text.replace("\\)", "$")
    text = text.replace("\\[", "$$")
    text = text.replace("\\]", "$$")
    text = text.replace("<|assistant|>", "")
    text = text.replace("<|observation|>", "")
    text = text.replace("<|system|>", "")
    text = text.replace("<|user|>", "")
    return text.strip()
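To make the prompt format concrete, here is a small illustrative driver (not part of this commit) showing the rough shape of the string `preprocess_text` assembles:

```python
from conversation import Conversation, Role, preprocess_text

prompt = preprocess_text(
    system='You are a helpful assistant.',
    tools=None,
    history=[Conversation(role=Role.USER, content='Hello!')],
)
print(prompt)
# Expected shape (roles render through Role.__str__):
# <|system|>
# You are a helpful assistant.<|user|>
# Hello!<|assistant|>
```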