Commit

Add composite demo (directory)
abmfy committed Oct 28, 2023
1 parent 6a47566 commit f38b941
Showing 15 changed files with 1,115 additions and 1 deletion.
1 change: 0 additions & 1 deletion composite_demo
Submodule composite_demo deleted from 496992
2 changes: 2 additions & 0 deletions composite_demo/.streamlit/config.toml
@@ -0,0 +1,2 @@
[theme]
font = "monospace"
85 changes: 85 additions & 0 deletions composite_demo/README.md
@@ -0,0 +1,85 @@
# ChatGLM3 Web Demo

![Demo webpage](assets/demo.png)

## Installation

We recommend using [Conda](https://docs.conda.io/en/latest/) for environment management.

Run the following commands to create a new conda environment and install the required dependencies:

```bash
conda create -n chatglm3-demo python=3.10
conda activate chatglm3-demo
pip install -r requirements.txt
```

Note that this project requires Python 3.10 or higher.

In addition, using the Code Interpreter also requires installing a Jupyter kernel:

```bash
ipython kernel install --name chatglm3-demo --user
```

## Running

Run the following command to load the model locally and start the demo:

```bash
streamlit run main.py
```

The demo's address will then appear on the command line; click it to visit. The first visit requires downloading and loading the model, which may take some time.

If the model has already been downloaded locally, you can load it from the local copy by setting `export MODEL_PATH=/path/to/model`. To use a custom Jupyter kernel, set `export IPYKERNEL=<kernel_name>`.
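
For example, assuming the weights were downloaded earlier and the kernel from the installation step is used (both values below are placeholders), the demo could be started like this:

```bash
# Hypothetical local setup; adjust the path to your model directory.
export MODEL_PATH=/path/to/chatglm3-6b
export IPYKERNEL=chatglm3-demo  # kernel name registered during installation
streamlit run main.py
```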

## Usage

The ChatGLM3 Demo has three modes:

- Chat: dialogue mode, in which you can converse with the model.
- Tool: tool mode, in which the model can invoke tools in addition to chatting.
- Code Interpreter: code interpreter mode, in which the model can execute code in a Jupyter environment and obtain the results to complete complex tasks.

### Chat Mode

In chat mode, you can adjust parameters such as top_p, temperature, and the System Prompt directly in the sidebar to tune the model's behavior. For example:

![The model responses following system prompt](assets/emojis.png)

### Tool Mode

You can extend the model's capabilities by registering new tools in `tool_registry.py`; decorating a function with `@register_tool` completes the registration. For the tool declaration, the function name becomes the tool's name and the function docstring becomes its description; for the tool's parameters, use `Annotated[typ: type, description: str, required: bool]` to declare each parameter's type, description, and whether it is required.

For example, the `get_weather` tool is registered as follows:

```python
@register_tool
def get_weather(
    city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the weather for `city_name` in the following week
    """
    ...
```

![The model uses a tool to query the weather of Paris.](assets/tool.png)
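
To illustrate the `Annotated[type, description, required]` convention with more than one parameter, a hypothetical tool (not part of this commit) could be registered like this:

```python
@register_tool
def get_exchange_rate(
    base: Annotated[str, 'Base currency code, e.g. USD', True],
    target: Annotated[str, 'Target currency code, e.g. EUR', True],
) -> str:
    """
    Query the current exchange rate from `base` to `target`
    """
    ...
```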

You can also enter manual mode via `Manual mode` on the page. In this mode, you specify the tool list directly in YAML, but you must feed the tools' outputs back to the model yourself.

### Code Interpreter Mode

With a code execution environment available, the model can complete more complex tasks in this mode, such as drawing charts or performing symbolic computation. The model automatically executes multiple code blocks in sequence, based on its understanding of how the task is progressing, until the task is complete. Therefore, in this mode you only need to state the task you want the model to perform.

For example, we can ask ChatGLM3 to draw a heart:

![The code interpreter draws a heart according to the user's instructions.](assets/heart.png)

### Extra Tips

- While the model is generating text, you can interrupt it with the `Stop` button in the upper-right corner of the page.
- Refresh the page to clear the conversation history.

# Enjoy!
Binary file added composite_demo/assets/demo.png
Binary file added composite_demo/assets/emojis.png
Binary file added composite_demo/assets/heart.png
Binary file added composite_demo/assets/tool.png
137 changes: 137 additions & 0 deletions composite_demo/client.py
@@ -0,0 +1,137 @@
from __future__ import annotations

from collections.abc import Iterable
import os
from typing import Any, Protocol

from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

from conversation import Conversation

TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')

@st.cache_resource
def get_client() -> Client:
    client = HFClient(MODEL_PATH)
    return client

class Client(Protocol):
    def generate_stream(self,
        system: str | None,
        tools: list[dict] | None,
        history: list[Conversation],
        **parameters: Any
    ) -> Iterable[TextGenerationStreamResponse]:
        ...

# Adapted from the ChatGLM3 modeling code; `self` is the model instance, so the
# function is invoked as `stream_chat(model, tokenizer, ...)`.
def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] | None = None, role: str = "user",
                past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                logits_processor=None, return_past_key_values=False, **kwargs):

    from transformers.generation.logits_process import LogitsProcessor
    from transformers.generation.utils import LogitsProcessorList

    class InvalidScoreLogitsProcessor(LogitsProcessor):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # Recover from NaN/Inf logits by forcing a fixed token's score high.
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
                scores[..., 5] = 5e4
            return scores

    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                  "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    if past_key_values is None:
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
    else:
        inputs = tokenizer.build_chat_input(query, role=role)
    inputs = inputs.to(self.device)
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[0]
        if self.transformer.pre_seq_len is not None:
            past_length -= self.transformer.pre_seq_len
        inputs.position_ids += past_length
        attention_mask = inputs.attention_mask
        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
        inputs['attention_mask'] = attention_mask
    history.append({"role": role, "content": query})
    for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                        eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                        **gen_kwargs):
        if return_past_key_values:
            outputs, past_key_values = outputs
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        # Skip partial UTF-8 sequences (rendered as the replacement character).
        if response and response[-1] != "�":
            new_history = history
            if return_past_key_values:
                yield response, new_history, past_key_values
            else:
                yield response, new_history

class HFClient(Client):
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(
            'cuda' if torch.cuda.is_available() else
            'mps' if torch.backends.mps.is_available() else
            'cpu'
        )
        self.model = self.model.eval()

    def generate_stream(self,
        system: str | None,
        tools: list[dict] | None,
        history: list[Conversation],
        **parameters: Any
    ) -> Iterable[TextGenerationStreamResponse]:
        chat_history = [{
            'role': 'system',
            'content': system if not tools else TOOL_PROMPT,
        }]

        if tools:
            chat_history[0]['tools'] = tools

        for conversation in history[:-1]:
            chat_history.append({
                'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
                'content': conversation.content,
            })

        query = history[-1].content
        role = str(history[-1].role).removeprefix('<|').removesuffix('|>')

        text = ''

        for new_text, _ in stream_chat(self.model,
            self.tokenizer,
            query,
            chat_history,
            role,
            **parameters,
        ):
            word = new_text.removeprefix(text)
            word_stripped = word.strip()
            text = new_text
            yield TextGenerationStreamResponse(
                generated_text=text,
                token=Token(
                    id=0,
                    logprob=0,
                    text=word,
                    special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
                )
            )
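
For context, here is a minimal, hypothetical usage sketch of the client defined above (not part of this commit); it assumes the demo's `Conversation` and `Role` types from `conversation.py`:

```python
# Hypothetical sketch: stream a chat reply through HFClient.
from client import get_client
from conversation import Conversation, Role

client = get_client()
history = [Conversation(Role.USER, 'Hello!')]
for response in client.generate_stream(
    system='You are a helpful assistant.',
    tools=None,
    history=history,
    top_p=0.8,
    temperature=0.8,
):
    print(response.token.text, end='', flush=True)
```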
119 changes: 119 additions & 0 deletions composite_demo/conversation.py
@@ -0,0 +1,119 @@
from dataclasses import dataclass
from enum import auto, Enum
import json

from PIL.Image import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator

TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n'

class Role(Enum):
    SYSTEM = auto()
    USER = auto()
    ASSISTANT = auto()
    TOOL = auto()
    INTERPRETER = auto()
    OBSERVATION = auto()

    def __str__(self):
        match self:
            case Role.SYSTEM:
                return "<|system|>"
            case Role.USER:
                return "<|user|>"
            case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
                return "<|assistant|>"
            case Role.OBSERVATION:
                return "<|observation|>"

    # Get the message block for the given role
    def get_message(self):
        # Compare by value here, because the enum object in the session state
        # is not the same as the enum cases here, due to streamlit's rerunning
        # behavior.
        match self.value:
            case Role.SYSTEM.value:
                return
            case Role.USER.value:
                return st.chat_message(name="user", avatar="user")
            case Role.ASSISTANT.value:
                return st.chat_message(name="assistant", avatar="assistant")
            case Role.TOOL.value:
                return st.chat_message(name="tool", avatar="assistant")
            case Role.INTERPRETER.value:
                return st.chat_message(name="interpreter", avatar="assistant")
            case Role.OBSERVATION.value:
                return st.chat_message(name="observation", avatar="user")
            case _:
                st.error(f'Unexpected role: {self}')

@dataclass
class Conversation:
    role: Role
    content: str
    tool: str | None = None
    image: Image | None = None

    def __str__(self) -> str:
        print(self.role, self.content, self.tool)
        match self.role:
            case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
                return f'{self.role}\n{self.content}'
            case Role.TOOL:
                return f'{self.role}{self.tool}\n{self.content}'
            case Role.INTERPRETER:
                return f'{self.role}interpreter\n{self.content}'

    # Human readable format
    def get_text(self) -> str:
        text = postprocess_text(self.content)
        match self.role.value:
            case Role.TOOL.value:
                text = f'Calling tool `{self.tool}`:\n{text}'
            case Role.INTERPRETER.value:
                text = f'{text}'
            case Role.OBSERVATION.value:
                text = f'Observation:\n```\n{text}\n```'
        return text

    # Display as a markdown block
    def show(self, placeholder: DeltaGenerator | None = None) -> None:
        if placeholder:
            message = placeholder
        else:
            message = self.role.get_message()
        if self.image:
            message.image(self.image)
        else:
            text = self.get_text()
            message.markdown(text)

def preprocess_text(
    system: str | None,
    tools: list[dict] | None,
    history: list[Conversation],
) -> str:
    if tools:
        tools = json.dumps(tools, indent=4, ensure_ascii=False)

    prompt = f"{Role.SYSTEM}\n"
    prompt += system if not tools else TOOL_PROMPT
    if tools:
        tools = json.loads(tools)
        prompt += json.dumps(tools, ensure_ascii=False)
    for conversation in history:
        prompt += f'{conversation}'
    prompt += f'{Role.ASSISTANT}\n'
    return prompt

def postprocess_text(text: str) -> str:
    # Convert LaTeX delimiters to the $ / $$ forms that Streamlit markdown renders,
    # and strip special role tokens from the displayed text.
    text = text.replace(r"\(", "$")
    text = text.replace(r"\)", "$")
    text = text.replace(r"\[", "$$")
    text = text.replace(r"\]", "$$")
    text = text.replace("<|assistant|>", "")
    text = text.replace("<|observation|>", "")
    text = text.replace("<|system|>", "")
    text = text.replace("<|user|>", "")
    return text.strip()
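
A minimal sketch (hypothetical, not part of this commit) of how `preprocess_text` serializes a conversation into the ChatGLM3 prompt format:

```python
# Hypothetical sketch: build a prompt for a single user turn.
from conversation import Conversation, Role, preprocess_text

history = [Conversation(Role.USER, 'What is 2 + 2?')]
prompt = preprocess_text(system='You are a helpful assistant.', tools=None, history=history)
# prompt == '<|system|>\nYou are a helpful assistant.<|user|>\nWhat is 2 + 2?<|assistant|>\n'
```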