VLLM Compatibility Update (#222)
* Upgrade to latest vllm

* pin numpy/pandas

* pin vllm/torch

* separate requirements_test.txt

* fix sqlite3

* fix tests add more deps

* add FAC to column_ner

* pin spacy

* fix entities
wongjingping authored Nov 1, 2024
1 parent 2b25e9c commit 0262d0e
Showing 6 changed files with 47 additions and 13 deletions.
11 changes: 6 additions & 5 deletions .github/workflows/main.yml
@@ -6,21 +6,22 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
   test:
     runs-on: ubuntu-latest
     needs: lint
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
      - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.11'
+          python-version: '3.10'
           cache: 'pip'
       - name: Install pip dependencies
         run: |
-          pip install -r requirements.txt
+          pip install --upgrade pip setuptools
+          pip install -r requirements_test.txt
           pip install pytest
       - name: Download spaCy model
         run: python -m spacy download en_core_web_sm
9 changes: 5 additions & 4 deletions requirements.txt
@@ -3,8 +3,9 @@ argparse
 func_timeout
 mistralai
 mysql-connector-python
+numpy==2.1.2
 openai>=1.1.0
-pandas
+pandas==2.2.3
 pandas-gbq
 peft
 psycopg2-binary
@@ -15,11 +16,11 @@ sentence-transformers
 snowflake-connector-python
 spacy
 sqlalchemy
-tiktoken==0.7.0
+tiktoken
 together
-torch
+torch==2.4.0
 tqdm
 transformers
 sqlparse
 sqlglot
-vllm; sys_platform != 'darwin'
+vllm==0.6.3.post1; sys_platform != 'darwin'
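
The `sys_platform != 'darwin'` suffix is a standard PEP 508 environment marker, so pip skips vLLM entirely on macOS while the pinned 0.6.3.post1 build installs everywhere else. A minimal sketch of how such a marker evaluates, using the `packaging` library (not part of this diff):

from packaging.markers import Marker

# Evaluates the marker against the current interpreter's environment.
marker = Marker("sys_platform != 'darwin'")
print(marker.evaluate())  # False on macOS, True on Linux and Windows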
13 changes: 13 additions & 0 deletions requirements_test.txt
@@ -0,0 +1,13 @@
+func_timeout
+numpy
+openai
+pandas
+psycopg2-binary
+pysqlite3
+sentence_transformers
+snowflake-connector-python
+spacy==3.7.2
+sqlalchemy
+sqlglot
+torch
+tqdm
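
`pysqlite3` appears here, which together with the "fix sqlite3" line in the commit message suggests the usual workaround for hosts whose bundled SQLite is too old. The swap itself is not visible in this diff; a common, purely hypothetical pattern looks like:

# Hypothetical sketch, not taken from this PR: alias the newer pysqlite3
# build over the stdlib sqlite3 before any other module imports it.
import sys
import pysqlite3

sys.modules["sqlite3"] = pysqlite3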
1 change: 0 additions & 1 deletion run_model_cot.sh
@@ -49,7 +49,6 @@ for model_name in "${model_names[@]}"; do
     --api_url "http://localhost:${PORT}/generate" \
     --api_type "vllm" \
     -p 10 \
-    --cot_table_alias "prealias" \
     --logprobs
   # finally, kill the api server
   pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}"
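
For context, this script drives `utils/api_server.py` (patched below). A hedged sketch of a request to that `/generate` endpoint: beyond `prompt` and `sql_lora_path`, which the server code explicitly pops, the payload keys are assumptions about typical SamplingParams fields, and the port depends on the script's ${PORT} variable:

import requests

# Illustrative payload only: remaining keys are forwarded to SamplingParams,
# and a stray "use_beam_search" would now be dropped server-side on vllm >= 0.6.2.
payload = {
    "prompt": "Generate a SQL query ...",
    "temperature": 0.0,
    "max_tokens": 256,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.json())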
5 changes: 5 additions & 0 deletions tests/test_utils_pruning.py
@@ -33,6 +33,11 @@ def test_metadata():
             "airport.airport_name,text,name of airport",
             "flight.airport_name,text,name of the airport",
         ],
+        "FAC": [
+            "country.name,text,country name",
+            "airport.airport_name,text,name of airport",
+            "flight.airport_name,text,name of the airport",
+        ],
         "PERSON": ["flight.pilot_name,text,name of the pilot"],
     }
     column_join = {("airport", "country"): [("airport.country_id", "country.id")]}
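
Why add `FAC`? In spaCy's OntoNotes-trained NER models, `FAC` covers facilities such as buildings, airports, highways, and bridges, so airport-name columns need to be reachable under that label as well as `GPE`. A minimal sketch, assuming the `en_core_web_sm` model the CI workflow downloads:

import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp("List all flights departing from Changi Airport.")
for ent in doc.ents:
    # Depending on the model version, "Changi Airport" may be tagged FAC.
    print(ent.text, ent.label_)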
21 changes: 18 additions & 3 deletions utils/api_server.py
@@ -55,17 +55,32 @@ async def generate(request: Request) -> Response:
     sql_lora_path = request_dict.pop("sql_lora_path", None)
     request_dict.pop("sql_lora_name", None)
     lora_request = (
-        LoRARequest("sql_adapter", 1, sql_lora_path) if sql_lora_path else None
+        LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path=sql_lora_path)
+        if sql_lora_path
+        else None
     )
+    if vllm_version >= "0.6.2":
+        # remove use_beam_search if present as it's no longer supported
+        # see https://github.com/vllm-project/vllm/releases/tag/v0.6.2
+        if "use_beam_search" in request_dict:
+            request_dict.pop("use_beam_search")
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
     tokenizer = await engine.get_tokenizer()
     prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
-    # print(f"prompt_token_ids: {prompt_token_ids}")
     if prompt_token_ids[0] != tokenizer.bos_token_id:
         prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids
 
-    if vllm_version >= "0.4.2":
+    if vllm_version >= "0.6.3":
+        from vllm import TokensPrompt
+
+        results_generator = engine.generate(
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            sampling_params=sampling_params,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
+    elif vllm_version >= "0.4.2":
         results_generator = engine.generate(
             inputs={"prompt_token_ids": prompt_token_ids},
             sampling_params=sampling_params,
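
One subtlety in the server patch: `vllm_version >= "0.6.3"` compares version strings lexicographically, which works for the releases pinned here but misorders double-digit components. A sketch of the failure mode and the safer `packaging` comparison, offered as an editor's aside rather than as part of this PR:

from packaging.version import Version

print("0.10.0" >= "0.6.3")                    # False: lexicographic comparison
print(Version("0.10.0") >= Version("0.6.3"))  # True: semantic comparison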
