diff --git a/composite_demo b/composite_demo
deleted file mode 160000
index 49699259..00000000
--- a/composite_demo
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 49699259b41cc1e453168d89a67cc8c63bd5ca5e
diff --git a/composite_demo/.streamlit/config.toml b/composite_demo/.streamlit/config.toml
new file mode 100644
index 00000000..30bc5cce
--- /dev/null
+++ b/composite_demo/.streamlit/config.toml
@@ -0,0 +1,2 @@
+[theme]
+font = "monospace"
\ No newline at end of file
diff --git a/composite_demo/README.md b/composite_demo/README.md
new file mode 100644
index 00000000..428dd662
--- /dev/null
+++ b/composite_demo/README.md
@@ -0,0 +1,85 @@
+# ChatGLM3 Web Demo
+
+![Demo webpage](assets/demo.png)
+
+## Installation
+
+We recommend managing the environment with [Conda](https://docs.conda.io/en/latest/).
+
+Run the following commands to create a new conda environment and install the required dependencies:
+
+```bash
+conda create -n chatglm3-demo python=3.10
+conda activate chatglm3-demo
+pip install -r requirements.txt
+```
+
+Please note that this project requires Python 3.10 or higher.
+
+In addition, using the Code Interpreter also requires installing a Jupyter kernel:
+
+```bash
+ipython kernel install --name chatglm3-demo --user
+```
+
+## Running
+
+Run the following command to load the model locally and start the demo:
+
+```bash
+streamlit run main.py
+```
+
+The demo's address will then be shown on the command line; click it to open the demo. The first visit needs to download and load the model, which may take some time.
+
+If you have already downloaded the model locally, you can load it from there by setting `export MODEL_PATH=/path/to/model`. If you need a custom Jupyter kernel, you can specify it by setting the `IPYKERNEL` environment variable.
+
+## Usage
+
+The ChatGLM3 demo has three modes:
+
+- Chat: dialogue mode, in which you can chat with the model.
+- Tool: tool mode, in which the model can call tools in addition to chatting.
+- Code Interpreter: code interpreter mode, in which the model can execute code in a Jupyter environment and read back the results to complete complex tasks.
+
+### Chat mode
+
+In chat mode, you can adjust the model's behavior directly from the sidebar by changing parameters such as top_p, temperature, and the System Prompt. For example:
+
+![The model responds following the system prompt](assets/emojis.png)
+
+### Tool mode
+
+You can extend the model's capabilities by registering new tools in `tool_registry.py`; decorating a function with `@register_tool` is all it takes. For the tool declaration, the function name is the tool's name and the function docstring is the tool's description; for the tool's parameters, use `Annotated[typ: type, description: str, required: bool]` to annotate each parameter's type, description, and whether it is required.
+
+For example, the registration of the `get_weather` tool looks like this:
+
+```python
+@register_tool
+def get_weather(
+    city_name: Annotated[str, 'The name of the city to be queried', True],
+) -> str:
+    """
+    Get the weather for `city_name` in the following week
+    """
+    ...
+```
+
+![The model uses a tool to query the weather of Paris.](assets/tool.png)
+
+You can also switch to `Manual mode` on the page. In this mode you specify the tool list directly in YAML (see the sketch at the end of this README), but you have to feed the tools' outputs back to the model yourself.
+
+### Code Interpreter mode
+
+Since it has access to a code execution environment, the model can complete more complex tasks in this mode, such as drawing charts or performing symbolic computation. The model runs several code blocks in succession, based on its own judgment of how the task is progressing, until the task is done. Therefore, in this mode you only need to state the task you want the model to carry out.
+
+For example, we can ask ChatGLM3 to draw a heart:
+
+![The code interpreter draws a heart according to the user's instructions.](assets/heart.png)
+
+### Extra tips
+
+- While the model is generating text, you can interrupt it with the `Stop` button in the top-right corner of the page.
+- Refresh the page to clear the conversation history.
+
+# Enjoy!
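+
+For reference, as mentioned in the tool mode section above, here is a minimal sketch of what a manual-mode tool definition in YAML might look like. It mirrors the structure of `EXAMPLE_TOOL` in `demo_tool.py` (which is what the text area is pre-filled with); the field values are illustrative only:
+
+```yaml
+- name: get_current_weather
+  description: Get the current weather in a given location
+  parameters:
+    type: object
+    properties:
+      location:
+        type: string
+        description: The city and state, e.g. San Francisco, CA
+      unit:
+        type: string
+        enum: [celsius, fahrenheit]
+    required: [location]
+```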
\ No newline at end of file diff --git a/composite_demo/assets/demo.png b/composite_demo/assets/demo.png new file mode 100644 index 00000000..5d2b10f2 Binary files /dev/null and b/composite_demo/assets/demo.png differ diff --git a/composite_demo/assets/emojis.png b/composite_demo/assets/emojis.png new file mode 100644 index 00000000..9bc293d9 Binary files /dev/null and b/composite_demo/assets/emojis.png differ diff --git a/composite_demo/assets/heart.png b/composite_demo/assets/heart.png new file mode 100644 index 00000000..0ee1462c Binary files /dev/null and b/composite_demo/assets/heart.png differ diff --git a/composite_demo/assets/tool.png b/composite_demo/assets/tool.png new file mode 100644 index 00000000..29bca13e Binary files /dev/null and b/composite_demo/assets/tool.png differ diff --git a/composite_demo/client.py b/composite_demo/client.py new file mode 100644 index 00000000..07f84962 --- /dev/null +++ b/composite_demo/client.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from collections.abc import Iterable +import os +from typing import Any, Protocol + +from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token +import streamlit as st +import torch +from transformers import AutoModel, AutoTokenizer + +from conversation import Conversation + +TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:' + +MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b') + +@st.cache_resource +def get_client() -> Client: + client = HFClient(MODEL_PATH) + return client + +class Client(Protocol): + def generate_stream(self, + system: str | None, + tools: list[dict] | None, + history: list[Conversation], + **parameters: Any + ) -> Iterable[TextGenerationStreamResponse]: + ... 
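+
+# NOTE: `stream_chat` below appears to be adapted from the `stream_chat` method shipped with
+# the ChatGLM3 model's remote code. It is kept as a module-level function and called with the
+# model instance passed explicitly as `self` (see `HFClient.generate_stream`), so the demo can
+# stream tokens over a role-aware, dict-based history. It yields growing partial responses;
+# `generate_stream` diffs consecutive responses and wraps each new chunk in a
+# `TextGenerationStreamResponse` for the Streamlit frontend.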
+ +def stream_chat(self, tokenizer, query: str, history: list[tuple[str, str]] = None, role: str = "user", + past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, + logits_processor=None, return_past_key_values=False, **kwargs): + + from transformers.generation.logits_process import LogitsProcessor + from transformers.generation.utils import LogitsProcessorList + + class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None: + inputs = tokenizer.build_chat_input(query, history=history, role=role) + else: + inputs = tokenizer.build_chat_input(query, role=role) + inputs = inputs.to(self.device) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + history.append({"role": role, "content": query}) + for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, + eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + new_history = history + if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + +class HFClient(Client): + def __init__(self, model_path: str): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to( + 'cuda' if torch.cuda.is_available() else + 'mps' if torch.backends.mps.is_available() else + 'cpu' + ) + self.model = self.model.eval() + + def generate_stream(self, + system: str | None, + tools: list[dict] | None, + history: list[Conversation], + **parameters: Any + ) -> Iterable[TextGenerationStreamResponse]: + chat_history = [{ + 'role': 'system', + 'content': system if not tools else TOOL_PROMPT, + }] + + if tools: + chat_history[0]['tools'] = tools + + for conversation in history[:-1]: + chat_history.append({ + 'role': str(conversation.role).removeprefix('<|').removesuffix('|>'), + 'content': conversation.content, + }) + + query = history[-1].content + role = str(history[-1].role).removeprefix('<|').removesuffix('|>') + + text = '' + + for new_text, _ in stream_chat(self.model, + self.tokenizer, + query, + chat_history, + role, + **parameters, + ): + word = new_text.removeprefix(text) + word_stripped = word.strip() + text = new_text + yield TextGenerationStreamResponse( + 
generated_text=text, + token=Token( + id=0, + logprob=0, + text=word, + special=word_stripped.startswith('<|') and word_stripped.endswith('|>'), + ) + ) diff --git a/composite_demo/conversation.py b/composite_demo/conversation.py new file mode 100644 index 00000000..1ac13e77 --- /dev/null +++ b/composite_demo/conversation.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from enum import auto, Enum +import json + +from PIL.Image import Image +import streamlit as st +from streamlit.delta_generator import DeltaGenerator + +TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n' + +class Role(Enum): + SYSTEM = auto() + USER = auto() + ASSISTANT = auto() + TOOL = auto() + INTERPRETER = auto() + OBSERVATION = auto() + + def __str__(self): + match self: + case Role.SYSTEM: + return "<|system|>" + case Role.USER: + return "<|user|>" + case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER: + return "<|assistant|>" + case Role.OBSERVATION: + return "<|observation|>" + + # Get the message block for the given role + def get_message(self): + # Compare by value here, because the enum object in the session state + # is not the same as the enum cases here, due to streamlit's rerunning + # behavior. + match self.value: + case Role.SYSTEM.value: + return + case Role.USER.value: + return st.chat_message(name="user", avatar="user") + case Role.ASSISTANT.value: + return st.chat_message(name="assistant", avatar="assistant") + case Role.TOOL.value: + return st.chat_message(name="tool", avatar="assistant") + case Role.INTERPRETER.value: + return st.chat_message(name="interpreter", avatar="assistant") + case Role.OBSERVATION.value: + return st.chat_message(name="observation", avatar="user") + case _: + st.error(f'Unexpected role: {self}') + +@dataclass +class Conversation: + role: Role + content: str + tool: str | None = None + image: Image | None = None + + def __str__(self) -> str: + print(self.role, self.content, self.tool) + match self.role: + case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION: + return f'{self.role}\n{self.content}' + case Role.TOOL: + return f'{self.role}{self.tool}\n{self.content}' + case Role.INTERPRETER: + return f'{self.role}interpreter\n{self.content}' + + # Human readable format + def get_text(self) -> str: + text = postprocess_text(self.content) + match self.role.value: + case Role.TOOL.value: + text = f'Calling tool `{self.tool}`:\n{text}' + case Role.INTERPRETER.value: + text = f'{text}' + case Role.OBSERVATION.value: + text = f'Observation:\n```\n{text}\n```' + return text + + # Display as a markdown block + def show(self, placeholder: DeltaGenerator | None=None) -> str: + if placeholder: + message = placeholder + else: + message = self.role.get_message() + if self.image: + message.image(self.image) + else: + text = self.get_text() + message.markdown(text) + +def preprocess_text( + system: str | None, + tools: list[dict] | None, + history: list[Conversation], +) -> str: + if tools: + tools = json.dumps(tools, indent=4, ensure_ascii=False) + + prompt = f"{Role.SYSTEM}\n" + prompt += system if not tools else TOOL_PROMPT + if tools: + tools = json.loads(tools) + prompt += json.dumps(tools, ensure_ascii=False) + for conversation in history: + prompt += f'{conversation}' + prompt += f'{Role.ASSISTANT}\n' + return prompt + +def postprocess_text(text: str) -> str: + text = text.replace("\(", "$") + text = text.replace("\)", "$") + text = text.replace("\[", "$$") + text = text.replace("\]", "$$") + text = 
text.replace("<|assistant|>", "") + text = text.replace("<|observation|>", "") + text = text.replace("<|system|>", "") + text = text.replace("<|user|>", "") + return text.strip() \ No newline at end of file diff --git a/composite_demo/demo_chat.py b/composite_demo/demo_chat.py new file mode 100644 index 00000000..e8a2e430 --- /dev/null +++ b/composite_demo/demo_chat.py @@ -0,0 +1,77 @@ +import streamlit as st +from streamlit.delta_generator import DeltaGenerator + +from client import get_client +from conversation import postprocess_text, preprocess_text, Conversation, Role + +MAX_LENGTH = 8192 + +client = get_client() + +# Append a conversation into history, while show it in a new markdown block +def append_conversation( + conversation: Conversation, + history: list[Conversation], + placeholder: DeltaGenerator | None=None, +) -> None: + history.append(conversation) + conversation.show(placeholder) + +def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str): + placeholder = st.empty() + with placeholder.container(): + if 'chat_history' not in st.session_state: + st.session_state.chat_history = [] + + history: list[Conversation] = st.session_state.chat_history + + for conversation in history: + conversation.show() + + if prompt_text: + prompt_text = prompt_text.strip() + append_conversation(Conversation(Role.USER, prompt_text), history) + + input_text = preprocess_text( + system_prompt, + tools=None, + history=history, + ) + print("=== Input:") + print(input_text) + print("=== History:") + print(history) + + placeholder = st.empty() + message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + + output_text = '' + for response in client.generate_stream( + system_prompt, + tools=None, + history=history, + do_sample=True, + max_length=MAX_LENGTH, + temperature=temperature, + top_p=top_p, + stop_sequences=[str(Role.USER)], + ): + token = response.token + if response.token.special: + print("=== Output:") + print(output_text) + + match token.text.strip(): + case '<|user|>': + break + case _: + st.error(f'Unexpected special token: {token.text.strip()}') + break + output_text += response.token.text + markdown_placeholder.markdown(postprocess_text(output_text + '▌')) + + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) \ No newline at end of file diff --git a/composite_demo/demo_ci.py b/composite_demo/demo_ci.py new file mode 100644 index 00000000..23065333 --- /dev/null +++ b/composite_demo/demo_ci.py @@ -0,0 +1,327 @@ +import base64 +from io import BytesIO +import os +from pprint import pprint +import queue +import re +from subprocess import PIPE + +import jupyter_client +from PIL import Image +import streamlit as st +from streamlit.delta_generator import DeltaGenerator + +from client import get_client +from conversation import postprocess_text, preprocess_text, Conversation, Role + +IPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo') + +SYSTEM_PROMPT = '你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。' + +MAX_LENGTH = 8192 +TRUNCATE_LENGTH = 1024 + +client = get_client() + +class CodeKernel(object): + def __init__(self, + kernel_name='kernel', + kernel_id=None, + kernel_config_path="", + python_path=None, + ipython_path=None, + init_file_path="./startup.py", + verbose=1): + + self.kernel_name = kernel_name + self.kernel_id = kernel_id + 
self.kernel_config_path = kernel_config_path + self.python_path = python_path + self.ipython_path = ipython_path + self.init_file_path = init_file_path + self.verbose = verbose + + if python_path is None and ipython_path is None: + env = None + else: + env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path} + + # Initialize the backend kernel + self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL, + connection_file=self.kernel_config_path, + exec_files=[self.init_file_path], + env=env) + if self.kernel_config_path: + self.kernel_manager.load_connection_file() + self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) + print("Backend kernel started with the configuration: {}".format( + self.kernel_config_path)) + else: + self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE) + print("Backend kernel started with the configuration: {}".format( + self.kernel_manager.connection_file)) + + if verbose: + pprint(self.kernel_manager.get_connection_info()) + + # Initialize the code kernel + self.kernel = self.kernel_manager.blocking_client() + # self.kernel.load_connection_file() + self.kernel.start_channels() + print("Code kernel started.") + + def execute(self, code): + self.kernel.execute(code) + try: + shell_msg = self.kernel.get_shell_msg(timeout=30) + io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content'] + while True: + msg_out = io_msg_content + ### Poll the message + try: + io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content'] + if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle': + break + except queue.Empty: + break + + return shell_msg, msg_out + except Exception as e: + print(e) + return None + + def execute_interactive(self, code, verbose=False): + shell_msg = self.kernel.execute_interactive(code) + if shell_msg is queue.Empty: + if verbose: + print("Timeout waiting for shell message.") + self.check_msg(shell_msg, verbose=verbose) + + return shell_msg + + def inspect(self, code, verbose=False): + msg_id = self.kernel.inspect(code) + shell_msg = self.kernel.get_shell_msg(timeout=30) + if shell_msg is queue.Empty: + if verbose: + print("Timeout waiting for shell message.") + self.check_msg(shell_msg, verbose=verbose) + + return shell_msg + + def get_error_msg(self, msg, verbose=False) -> str | None: + if msg['content']['status'] == 'error': + try: + error_msg = msg['content']['traceback'] + except: + try: + error_msg = msg['content']['traceback'][-1].strip() + except: + error_msg = "Traceback Error" + if verbose: + print("Error: ", error_msg) + return error_msg + return None + + def check_msg(self, msg, verbose=False): + status = msg['content']['status'] + if status == 'ok': + if verbose: + print("Execution succeeded.") + elif status == 'error': + for line in msg['content']['traceback']: + if verbose: + print(line) + + def shutdown(self): + # Shutdown the backend kernel + self.kernel_manager.shutdown_kernel() + print("Backend kernel shutdown.") + # Shutdown the code kernel + self.kernel.shutdown() + print("Code kernel shutdown.") + + def restart(self): + # Restart the backend kernel + self.kernel_manager.restart_kernel() + # print("Backend kernel restarted.") + + def interrupt(self): + # Interrupt the backend kernel + self.kernel_manager.interrupt_kernel() + # print("Backend kernel interrupted.") + + def is_alive(self): + return self.kernel.is_alive() + +def b64_2_img(data): + buff = BytesIO(base64.b64decode(data)) + return Image.open(buff) + +def clean_ansi_codes(input_string): + 
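+    # IPython tracebacks come wrapped in ANSI colour escape sequences; strip them so the
+    # error text reads cleanly in the Streamlit UI and in the observation fed back to the model.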
ansi_escape = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', input_string) + +def execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]: + res = "" + res_type = None + code = code.replace("<|observation|>", "") + code = code.replace("<|assistant|>interpreter", "") + code = code.replace("<|assistant|>", "") + code = code.replace("<|user|>", "") + code = code.replace("<|system|>", "") + msg, output = kernel.execute(code) + + if msg['metadata']['status'] == "timeout": + return res_type, 'Timed out' + elif msg['metadata']['status'] == 'error': + return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True))) + + if 'text' in output: + res_type = "text" + res = output['text'] + elif 'data' in output: + for key in output['data']: + if 'text/plain' in key: + res_type = "text" + res = output['data'][key] + elif 'image/png' in key: + res_type = "image" + res = output['data'][key] + break + + if res_type == "image": + return res_type, b64_2_img(res) + elif res_type == "text" or res_type == "traceback": + res = res + + return res_type, res + +@st.cache_resource +def get_kernel(): + kernel = CodeKernel() + return kernel + +def extract_code(text: str) -> str: + pattern = r'```([^\n]*)\n(.*?)```' + matches = re.findall(pattern, text, re.DOTALL) + return matches[-1][1] + +# Append a conversation into history, while show it in a new markdown block +def append_conversation( + conversation: Conversation, + history: list[Conversation], + placeholder: DeltaGenerator | None=None, +) -> None: + history.append(conversation) + conversation.show(placeholder) + +def main(top_p: float, temperature: float, prompt_text: str): + if 'ci_history' not in st.session_state: + st.session_state.ci_history = [] + + history: list[Conversation] = st.session_state.ci_history + + for conversation in history: + conversation.show() + + if prompt_text: + prompt_text = prompt_text.strip() + role = Role.USER + append_conversation(Conversation(role, prompt_text), history) + + input_text = preprocess_text( + SYSTEM_PROMPT, + None, + history, + ) + print("=== Input:") + print(input_text) + print("=== History:") + print(history) + + placeholder = st.container() + message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + + for _ in range(5): + output_text = '' + for response in client.generate_stream( + system=SYSTEM_PROMPT, + tools=None, + history=history, + do_sample=True, + max_length=MAX_LENGTH, + temperature=temperature, + top_p=top_p, + stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)], + ): + token = response.token + if response.token.special: + print("=== Output:") + print(output_text) + + match token.text.strip(): + case '<|user|>': + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) + return + # Initiate tool call + case '<|assistant|>': + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) + message_placeholder = placeholder.chat_message(name="interpreter", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + output_text = '' + continue + case '<|observation|>': + code = extract_code(output_text) + print("Code:", code) + + display_text = output_text.split('interpreter')[-1].strip() + append_conversation(Conversation( + Role.INTERPRETER, + postprocess_text(display_text), + ), history, markdown_placeholder) + 
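+                    # The model ended its turn with <|observation|>: it expects the code block it
+                    # just produced to be executed. Switch the chat bubble to the observation role,
+                    # run the extracted code in the Jupyter kernel, and append the (possibly
+                    # truncated) result as an OBSERVATION message so the next round of generation
+                    # can react to it.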
message_placeholder = placeholder.chat_message(name="observation", avatar="user") + markdown_placeholder = message_placeholder.empty() + output_text = '' + + with markdown_placeholder: + with st.spinner('Executing code...'): + try: + res_type, res = execute(code, get_kernel()) + except Exception as e: + st.error(f'Error when executing code: {e}') + return + print("Received:", res_type, res) + + if res_type == 'text' and len(res) > TRUNCATE_LENGTH: + res = res[:TRUNCATE_LENGTH] + ' [TRUNCATED]' + + append_conversation(Conversation( + Role.OBSERVATION, + '[Image]' if res_type == 'image' else postprocess_text(res), + tool=None, + image=res if res_type == 'image' else None, + ), history, markdown_placeholder) + message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + output_text = '' + break + case _: + st.error(f'Unexpected special token: {token.text.strip()}') + break + output_text += response.token.text + display_text = output_text.split('interpreter')[-1].strip() + markdown_placeholder.markdown(postprocess_text(display_text + '▌')) + else: + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) + return \ No newline at end of file diff --git a/composite_demo/demo_tool.py b/composite_demo/demo_tool.py new file mode 100644 index 00000000..bef972cf --- /dev/null +++ b/composite_demo/demo_tool.py @@ -0,0 +1,191 @@ +import re +import yaml +from yaml import YAMLError + +import streamlit as st +from streamlit.delta_generator import DeltaGenerator + +from client import get_client +from conversation import postprocess_text, preprocess_text, Conversation, Role +from tool_registry import dispatch_tool, get_tools + +MAX_LENGTH = 8192 +TRUNCATE_LENGTH = 1024 + +EXAMPLE_TOOL = { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + } +} + +client = get_client() + +def tool_call(*args, **kwargs) -> dict: + print("=== Tool call:") + print(args) + print(kwargs) + st.session_state.calling_tool = True + return kwargs + +def yaml_to_dict(tools: str) -> list[dict] | None: + try: + return yaml.safe_load(tools) + except YAMLError: + return None + +def extract_code(text: str) -> str: + pattern = r'```([^\n]*)\n(.*?)```' + matches = re.findall(pattern, text, re.DOTALL) + return matches[-1][1] + +# Append a conversation into history, while show it in a new markdown block +def append_conversation( + conversation: Conversation, + history: list[Conversation], + placeholder: DeltaGenerator | None=None, +) -> None: + history.append(conversation) + conversation.show(placeholder) + +def main(top_p: float, temperature: float, prompt_text: str): + manual_mode = st.toggle('Manual mode', + help='Define your tools in YAML format. You need to supply tool call results manually.' 
+ ) + + if manual_mode: + with st.expander('Tools'): + tools = st.text_area( + 'Define your tools in YAML format here:', + yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False), + height=400, + ) + tools = yaml_to_dict(tools) + + if not tools: + st.error('YAML format error in tools definition') + else: + tools = get_tools() + + if 'tool_history' not in st.session_state: + st.session_state.tool_history = [] + if 'calling_tool' not in st.session_state: + st.session_state.calling_tool = False + + history: list[Conversation] = st.session_state.tool_history + + for conversation in history: + conversation.show() + + if prompt_text: + prompt_text = prompt_text.strip() + role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER + append_conversation(Conversation(role, prompt_text), history) + st.session_state.calling_tool = False + + input_text = preprocess_text( + None, + tools, + history, + ) + print("=== Input:") + print(input_text) + print("=== History:") + print(history) + + placeholder = st.container() + message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + + for _ in range(5): + output_text = '' + for response in client.generate_stream( + system=None, + tools=tools, + history=history, + do_sample=True, + max_length=MAX_LENGTH, + temperature=temperature, + top_p=top_p, + stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)], + ): + token = response.token + if response.token.special: + print("=== Output:") + print(output_text) + + match token.text.strip(): + case '<|user|>': + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) + return + # Initiate tool call + case '<|assistant|>': + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) + output_text = '' + message_placeholder = placeholder.chat_message(name="tool", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + continue + case '<|observation|>': + tool, *output_text = output_text.strip().split('\n') + output_text = '\n'.join(output_text) + + append_conversation(Conversation( + Role.TOOL, + postprocess_text(output_text), + tool, + ), history, markdown_placeholder) + message_placeholder = placeholder.chat_message(name="observation", avatar="user") + markdown_placeholder = message_placeholder.empty() + + try: + code = extract_code(output_text) + args = eval(code, {'tool_call': tool_call}, {}) + except: + st.error('Failed to parse tool call') + return + + output_text = '' + + if manual_mode: + st.info('Please provide tool call results below:') + return + else: + with markdown_placeholder: + with st.spinner(f'Calling tool {tool}...'): + observation = dispatch_tool(tool, args) + + if len(observation) > TRUNCATE_LENGTH: + observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]' + append_conversation(Conversation( + Role.OBSERVATION, observation + ), history, markdown_placeholder) + message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") + markdown_placeholder = message_placeholder.empty() + st.session_state.calling_tool = False + break + case _: + st.error(f'Unexpected special token: {token.text.strip()}') + return + output_text += response.token.text + markdown_placeholder.markdown(postprocess_text(output_text + '▌')) + else: + append_conversation(Conversation( + Role.ASSISTANT, + postprocess_text(output_text), + ), history, markdown_placeholder) + 
return diff --git a/composite_demo/main.py b/composite_demo/main.py new file mode 100644 index 00000000..c144cfbf --- /dev/null +++ b/composite_demo/main.py @@ -0,0 +1,56 @@ +from enum import Enum +import streamlit as st + +st.set_page_config( + page_title="ChatGLM3 Demo", + page_icon=":robot:", + layout='centered', + initial_sidebar_state='expanded', +) + +import demo_chat, demo_ci, demo_tool + +DEFAULT_SYSTEM_PROMPT = ''' +You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown. +'''.strip() + +class Mode(str, Enum): + CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑‍💻 Code Interpreter' + + +with st.sidebar: + top_p = st.slider( + 'top_p', 0.0, 1.0, 0.8, step=0.01 + ) + temperature = st.slider( + 'temperature', 0.0, 1.5, 0.95, step=0.01 + ) + system_prompt = st.text_area( + label="System Prompt (Only for chat mode)", + height=300, + value=DEFAULT_SYSTEM_PROMPT, + ) + +st.title("ChatGLM3 Demo") + +prompt_text = st.chat_input( + 'Chat with ChatGLM3!', + key='chat_input', +) + +tab = st.radio( + 'Mode', + [mode.value for mode in Mode], + horizontal=True, + label_visibility='hidden', +) + +match tab: + case Mode.CHAT: + demo_chat.main(top_p, temperature, system_prompt, prompt_text) + case Mode.TOOL: + demo_tool.main(top_p, temperature, prompt_text) + case Mode.CI: + demo_ci.main(top_p, temperature, prompt_text) + case _: + st.error(f'Unexpected tab: {tab}') diff --git a/composite_demo/requirements.txt b/composite_demo/requirements.txt new file mode 100644 index 00000000..27f603d0 --- /dev/null +++ b/composite_demo/requirements.txt @@ -0,0 +1,12 @@ +huggingface_hub +ipykernel +ipython +jupyter_client +pillow +sentencepiece +streamlit +tokenizers +torch +transformers +pyyaml +requests \ No newline at end of file diff --git a/composite_demo/tool_registry.py b/composite_demo/tool_registry.py new file mode 100644 index 00000000..e54564c5 --- /dev/null +++ b/composite_demo/tool_registry.py @@ -0,0 +1,109 @@ +from copy import deepcopy +import inspect +from pprint import pformat +import traceback +from types import GenericAlias +from typing import get_origin, Annotated + +_TOOL_HOOKS = {} +_TOOL_DESCRIPTIONS = {} + +def register_tool(func: callable): + tool_name = func.__name__ + tool_description = inspect.getdoc(func).strip() + python_params = inspect.signature(func).parameters + tool_params = [] + for name, param in python_params.items(): + annotation = param.annotation + if annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter `{name}` missing type annotation") + if get_origin(annotation) != Annotated: + raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") + + typ, (description, required) = annotation.__origin__, annotation.__metadata__ + typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ + if not isinstance(description, str): + raise TypeError(f"Description for `{name}` must be a string") + if not isinstance(required, bool): + raise TypeError(f"Required for `{name}` must be a bool") + + tool_params.append({ + "name": name, + "description": description, + "type": typ, + "required": required + }) + tool_def = { + "name": tool_name, + "description": tool_description, + "params": tool_params + } + + print("[registered tool] " + pformat(tool_def)) + _TOOL_HOOKS[tool_name] = func + _TOOL_DESCRIPTIONS[tool_name] = tool_def + + return func + +def dispatch_tool(tool_name: str, tool_params: dict) -> str: + if tool_name not in _TOOL_HOOKS: + return f"Tool `{tool_name}` not found. 
Please use a provided tool." + tool_call = _TOOL_HOOKS[tool_name] + try: + ret = tool_call(**tool_params) + except: + ret = traceback.format_exc() + return str(ret) + +def get_tools() -> dict: + return deepcopy(_TOOL_DESCRIPTIONS) + +# Tool Definitions + +@register_tool +def random_number_generator( + seed: Annotated[int, 'The random seed used by the generator', True], + range: Annotated[tuple[int, int], 'The range of the generated numbers', True], +) -> int: + """ + Generates a random number x, s.t. range[0] <= x < range[1] + """ + if not isinstance(seed, int): + raise TypeError("Seed must be an integer") + if not isinstance(range, tuple): + raise TypeError("Range must be a tuple") + if not isinstance(range[0], int) or not isinstance(range[1], int): + raise TypeError("Range must be a tuple of integers") + + import random + return random.Random(seed).randint(*range) + +@register_tool +def get_weather( + city_name: Annotated[str, 'The name of the city to be queried', True], +) -> str: + """ + Get the current weather for `city_name` + """ + + if not isinstance(city_name, str): + raise TypeError("City name must be a string") + + key_selection = { + "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], + } + import requests + try: + resp = requests.get(f"https://wttr.in/{city_name}?format=j1") + resp.raise_for_status() + resp = resp.json() + ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} + except: + import traceback + ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() + + return str(ret) + +if __name__ == "__main__": + print(dispatch_tool("get_weather", {"city_name": "beijing"})) + print(get_tools())