-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Xuefei Ning <[email protected]> Co-authored-by: Zixuan Zhou <[email protected]> Co-authored-by: Zifu Wang <[email protected]>
- Loading branch information
Showing
74 changed files
with
12,592 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
results | ||
.DS_Store | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# poetry | ||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
# This is especially recommended for binary packages to ensure reproducibility, and is more | ||
# commonly ignored for libraries. | ||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
#poetry.lock | ||
|
||
# pdm | ||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
#pdm.lock | ||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
# in version control. | ||
# https://pdm.fming.dev/#use-with-ide | ||
.pdm.toml | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# PyCharm | ||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
# and can be added to the global gitignore or merged into this file. For a more nuclear | ||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
#.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 23.3.0 | ||
hooks: | ||
- id: black | ||
args: [--preview] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) [2023] [Xuefei Ning, Zinan Lin, Zixuan Zhou, Zifu Wang, Huazhong Yang, Yu Wang] | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Skeleton-of-Thought: Large Language Models Can Do Parallel Decoding | ||
|
||
**[[website](https://sites.google.com/view/sot-llm/home)]** | ||
**[[paper](https://arxiv.org/abs/2307.15337)]** | ||
**[[code](https://github.com/imagination-research/sot)]** | ||
|
||
This work aims at decreasing the end-to-end generation latency of large language models (LLMs). One of the major causes of the high generation latency is the sequential decoding approach adopted by almost all state-of-the-art LLMs. In this work, motivated by the thinking and writing process of humans, we propose Skeleton-of-Thought (SoT), which first guides LLMs to generate the skeleton of the answer, and then conducts parallel API calls or batched decoding to complete the contents of each skeleton point in parallel. Not only does SoT provide considerable speed-ups across 12 LLMs, but it can also potentially improve the answer quality on several question categories. To make the overall solution more practical, an extension, SoT with Router (SoT-R), employs a GPT-4-prompting router or a trained RoBERTa router to only trigger SoT for suitable questions. SoT is an initial attempt at data-centric optimization for inference efficiency, and further underscores the potential of pushing LLMs to think more like a human for answer quality. | ||
|
||
|
||
If you find this repository or paper useful, you can cite | ||
``` | ||
@misc{ning2023skeletonofthought, | ||
title={Skeleton-of-Thought: Large Language Models Can Do Parallel Decoding}, | ||
author={Xuefei Ning and Zinan Lin and Zixuan Zhou and Zifu Wang and Huazhong Yang and Yu Wang}, | ||
year={2023}, | ||
eprint={2307.15337}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CL} | ||
} | ||
``` | ||
|
||
The repo is organized as follows. | ||
* The SoT implementation is under [`sot/`](sot/). | ||
* The SoT prompts are given under [`prompts/`](prompts/). For example, `sot_opensource.json` is used for all open-source models, and `sot_gpt4` is used for the GPT-4 API. | ||
* The processed data are under [`data/`](data/). | ||
* The scripts under [`scripts/`](scripts/) are used to dump and evaluate the results. | ||
* The Gradio demo code is under [`demo/`](demo/). The demo is built based on the FastChat demo code. | ||
|
||
## Contents | ||
- [Install](#install) | ||
- [Test SoT with Gradio Demo](#test-sot-with-gradio-demo) | ||
- [Evaluate SoT](#evaluate-sot) | ||
- [Develop SoT](#develop-sot) | ||
- [Acknowledgement](#acknowledgement) | ||
|
||
|
||
## Install | ||
```pip install -e .``` | ||
|
||
We recommend using Python 3.8 to 3.10. | ||
|
||
Some required environment variables/setups: | ||
* Before running the open-source models, please log in to huggingface through `huggingface-cli login` so that the models can be downloaded automatically. | ||
* For GPT4, the script by default uses **OpenAI API**. The API key should be provided by `export OPENAI_API_KEY=<API key>`. | ||
* For GPT-3.5, the script by default uses **Azure OpenAI API**. The API key, engine, and API base should be provided by `export OPENAI_API_KEY=<API key>`, `export ENGINE=<engine>`, and `export API_BASE=<API base>`. | ||
> Note that GPT-4 can also use **Azure OpenAI API**, and GPT-3.5 can also use **OpenAI API**, by modifying the command line arguments accordingly. | ||
* For Claude, please refer to [Claude setup guide](claude_setup_guide.md). | ||
|
||
## Test SoT with Gradio Demo | ||
The SoT gradio demo can be started as follows (under the [`demo/`](demo/) directory): | ||
|
||
1. Launch the controller | ||
``` | ||
python controller.py | ||
``` | ||
2. Launch the model workers | ||
- Lauch a model worker that conducts normal decoding on GPU 0. | ||
``` | ||
CUDA_VISIBLE_DEVICES=0 python model_worker.py --model-path ${MODEL_NAME} --controller http://0.0.0.0:21001 --port 31000 --worker http://0.0.0.0:31000 | ||
``` | ||
- Lauch a model worker that conducts SoT-R decoding on GPU 1. | ||
``` | ||
CUDA_VISIBLE_DEVICES=1 python model_worker.py --model-path ${MODEL_NAME} --controller http://0.0.0.0:21001 --port 31001 --worker http://0.0.0.0:31001 --sot ../prompts/sot_opensource.json --sotr ${ROUTER_MODEL} | ||
``` | ||
The trained router model can be downloaded from [this Google Drive](https://drive.google.com/file/d/1LxEsH9NFwj41wBz8tnT_hwn5LbW7aaL5/view?usp=sharing). | ||
- Note that we recommend directly using SoT-R instead of the plain SoT. But if one wants to trigger SoT for all questions, he or she can launch another model worker as follows: | ||
``` | ||
CUDA_VISIBLE_DEVICES=1 python model_worker.py --model-path ${MODEL_NAME} --controller http://0.0.0.0:21001 --port 31002 --worker http://0.0.0.0:31002 --sot ../prompts/sot_opensource.json | ||
``` | ||
3. Launch the Gradio web demo | ||
``` | ||
python gradio_web_server_multi.py | ||
``` | ||
## Evaluate SoT | ||
### Prepare the dataset | ||
Vicuna-80, WizardLM, and LIMA data is provided under [`data/`](data/) and is ready to use. The pre-processing scripts for getting the data are also attached (`create_dataset.py` in each folder) for reference. | ||
### Dump the answers of SoT and Normal decoding | ||
We put the answer dumping scripts for the Vicuna-80 and WizardLM datasets under [`scripts/vicuna/dump/`](scripts/vicuna/dump/) and [`scripts/wizardlm/dump/`](scripts/wizardlm/dump/). | ||
For example, to dump SoT answers of all open-source models, we can run | ||
``` | ||
python scripts/vicuna/dump/opensource_outline.py | ||
``` | ||
To dump the normal sequential decoding answers of GPT-3.5, we can run | ||
``` | ||
./scripts/vicuna/dump/gpt3.5_naive.sh | ||
``` | ||
### Evaluate the answer quality | ||
We put the evaluation scripts for the Vicuna-80 and WizardLM datasets under [`scripts/vicuna/evaluate/`](scripts/vicuna/evaluate/) and [`scripts/wizardlm/evaluate/`](scripts/wizardlm/evaluate/). | ||
The evaluation scripts use the comparison prompts provided by Fastchat or LLMZoo to prompt a GPT-4 judge to compare the quality of two answers. Please provide OpenAI API key by `export OPENAI_API_KEY=<API key>` before running the scripts. | ||
> Note: We did not use the system prompt except for the LLaMA-2 models when conducting open-source model evaluation in our paper (for both normal decoding and SoT decoding). This is because we use an [older FastChat version](https://github.com/lm-sys/FastChat/tree/f1f2294a66956b340c577fab2751d86f45e71099) for the evaluation in the paper. As our code removes the template messages in the conversation template before querying the model, the system prompt will be removed as template messages in the old FastChat version. Nevertheless, in this code repository, we use a newer version of FastChat (v0.2.26). Consequently, running SoT with the current code will use the system prompt for all open-source models. | ||
## Develop SoT | ||
### Manually tune the SoT prompts | ||
`sot/prompt_eng_main.py` is a helper program to ease manual prompt tuning. Use `bash scripts/debug_prompt.sh <model name or path>` to run the script. This will pop an interactive session in which you can run the following commands: | ||
1. `use <data filepath>` to load data (default: `data/vicuna/data.csv`) | ||
2. `useprompt <prompt filepath>` to change the SoT prompt templates (default: `prompts/sot_opensource.json`) | ||
3. `usenaiveprompt <prompt filepath>` to change the normal prompt template (default to use only the question) | ||
4. (1) `test <ind>` to test SoT decoding for the ind-th question; (2) `test naive <ind>` to test normal decoding; (3) `test batch_outline <ind>` to test SoT decoding with batched point expansion. | ||
* The model outputs will be streamed onto the console (by enabling `--stream` argument to `sot/prompt_eng_main.py`). Note that when using `test <ind>`, the expansion of multiple points is conducted sequentially. When using `test batch_outline <ind>`, the expansion of multiple points is conducted with batch inference, but we do not support streaming the parallel expansion outputs to the console (to check the streaming effect, use the Gradio Web Demo), so one have to wait until the point-expanding completion to see the results. | ||
* After the completion, statistics will also be printed. | ||
* At any time during the generation, one can push Ctrl+C to abort the generation to go back to the interactive session. | ||
5. `exit` to exit the session | ||
> Note: | ||
> 1. We mainly use this program to help engineer the prompt for the open-source models. | ||
> 2. Any other command-line arguments for the model can be fed as the arguments to this script. For example, as testing a 13B model on RTX 3090 with FP16 inference requires two GPUs, we can run | ||
> ```bash scripts/debug_prompt.sh meta-llama/Llama-2-13b-chat-hf --num-gpus 2``` | ||
### Train the router for SoT-R | ||
Preprocess router data and train the RoBERTa router as follows (scripts in [sot/train/](sot/train/)): | ||
1. Preprocess the router data for Vicuna-80, WizardLM, and LIMA: | ||
``` | ||
python offline_prepare_router_data.py \ | ||
--data_path "../../data/lima/router.csv" \ | ||
--output_data_path "lima_router.pkl" | ||
``` | ||
2. Train the router on LIMA and test on Vicuna-80 and WizardLM: | ||
``` | ||
python train_router.py | ||
``` | ||
The predicted results will be saved as `vicuna_router_pred.csv` and `wizardlm_router_pred.csv`. | ||
Our trained router model can be found on [this Google Drive](https://drive.google.com/file/d/1LxEsH9NFwj41wBz8tnT_hwn5LbW7aaL5/view?usp=sharing). | ||
Our manual labels of whether each question should use SoT are provided in `data/*/router.csv`. | ||
## Acknowledgement | ||
During the development of SoT, we use and refer to the amazing work of [FastChat](https://github.com/lm-sys/FastChat) and [Hugging Face transformer package](https://github.com/huggingface/transformers/). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
This page contains instructions about how to setup Claude over Slack for the evaluation. In essence, we will add Claude app into your Slack workspace, and set up another Slack app which the scripts will use to interact with the Claude app. | ||
|
||
Please follow these steps: | ||
|
||
1. Log in a *paid* slack workspace on browser. | ||
1. Open [https://www.anthropic.com/claude-in-slack](https://www.anthropic.com/claude-in-slack) and click "Add to Slack" so as to add Claude app into the Slack workspace. | ||
1. Open [https://api.slack.com/apps](https://api.slack.com/apps) and click "Create New App" to create a Slack app. | ||
1. Open "OAuth & Permissions" tab, and in "User Token Scopes" add the following permissions: | ||
* admin | ||
* channels:history | ||
* channels:read | ||
* channels:write | ||
* chat:write | ||
* groups:history | ||
* groups:read | ||
* groups:write | ||
* im:history | ||
* im:read | ||
* im:write | ||
* mpim:history | ||
* mpim:read | ||
* mpim:write | ||
* users:read | ||
1. Click "Reinstall to Workspace" so that the permission changes are applied | ||
1. Copy your "OAuth Tokens for Your Workspace" starting with "xoxp-" and execute `export SLACK_USER_TOKEN=<token>` in the command line. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
import csv | ||
import json | ||
from tqdm import tqdm | ||
from huggingface_hub import login, hf_hub_download | ||
|
||
|
||
JSONL_FILE = "train.jsonl" | ||
JSONL_URL = "GAIR/lima" | ||
DATA_FILE = "data.csv" | ||
|
||
ACCESS_TOKEN = os.getenv("ACCESS_TOKEN") | ||
|
||
|
||
def download(url, file, access_token): | ||
login(access_token) | ||
hf_hub_download(repo_id=url, filename=file, repo_type="dataset", local_dir=".") | ||
|
||
|
||
def save(file, request): | ||
with open(file, "a") as f: | ||
writer = csv.writer(f) | ||
writer.writerow([request]) | ||
|
||
|
||
if __name__ == "__main__": | ||
if not os.path.exists(JSONL_FILE): | ||
download(JSONL_URL, JSONL_FILE, ACCESS_TOKEN) | ||
|
||
save(DATA_FILE, "request") | ||
with open(JSONL_FILE) as f: | ||
for line in tqdm(f): | ||
request = line.strip() | ||
request = json.loads(request)["conversations"][0] | ||
|
||
save(DATA_FILE, request) |
Oops, something went wrong.