Together.ai Runner #215

Merged 2 commits on Aug 30, 2024

34 changes: 33 additions & 1 deletion README.md
@@ -379,6 +379,38 @@ python -W ignore main.py \
-n 10
```

### Bedrock

To run evaluations against models hosted on AWS Bedrock, make sure your AWS credentials are configured first (e.g. via `aws configure` or the standard AWS environment variables). You can then run:

```bash
python3 main.py \
-db postgres \
-q data/instruct_basic_postgres.csv data/instruct_advanced_postgres.csv data/questions_gen_postgres.csv \
-o results/bedrock_llama_70b_basic.csv results/bedrock_llama_70b_advanced.csv results/bedrock_llama_70b_v1.csv \
-g bedrock \
-f prompts/prompt_cot_postgres.md \
--cot_table_alias prealias \
-m meta.llama3-70b-instruct-v1:0 \
-c 0 \
-p 10
```
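
If you want to verify that your credentials can reach Bedrock before a full run, here is a minimal sketch using `boto3` (the region, request body, and response key follow the usual Llama-on-Bedrock conventions, but treat them as assumptions and adjust for your account):

```python
import json

import boto3

# assumes AWS credentials are already configured; the region is an example
client = boto3.client("bedrock-runtime", region_name="us-east-1")
response = client.invoke_model(
    modelId="meta.llama3-70b-instruct-v1:0",
    body=json.dumps({"prompt": "Reply with OK.", "max_gen_len": 16, "temperature": 0.0}),
)
# Llama models on Bedrock return their completion under the "generation" key
print(json.loads(response["body"].read())["generation"])
```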

### Together

Before running this, you must create an account with [Together.ai](https://together.ai/), obtain an API key, and store it with `export TOGETHER_API_KEY=<your_api_key>`. Then install the `together` package with `pip install together`. You can then run the following command:

```bash
python3 main.py \
-db postgres \
-q data/instruct_basic_postgres.csv data/instruct_advanced_postgres.csv data/questions_gen_postgres.csv \
-o results/together_llama_70b_basic.csv results/together_llama_70b_advanced.csv results/together_llama_70b_v1.csv \
-g together \
-f prompts/prompt_together.json \
--cot_table_alias prealias \
-m "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \
-c 0 \
-p 10
```
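
To sanity-check your API key and the `together` client before a full evaluation, here is a minimal sketch (it mirrors the API call the runner itself makes; the prompt is just a placeholder):

```python
import os

from together import Together

# assumes TOGETHER_API_KEY was exported as described above
client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    messages=[{"role": "user", "content": "Reply with OK."}],
    max_tokens=8,
    temperature=0.0,
)
print(response.choices[0].message.content)
```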

## CLI Flags

You can use the following flags in the command line to change the configurations of your evaluation runs.
@@ -397,7 +429,7 @@ You can use the following flags in the command line to change the configurations

| CLI Flags | Description |
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| -g, --model_type | Model type used. Make sure this matches the model used. Currently defined options in `main.py` are `oa` for OpenAI models, `anthropic` for Anthropic models, `hf` for Hugging Face models, `vllm` for a vllm runner, `api` for API endpoints, `llama_cpp` for llama cpp, and `mlx` for mlx |
| -g, --model_type | Model type used. Make sure this matches the model used. Currently defined options in `main.py` are `oa` for OpenAI models, `anthropic` for Anthropic models, `hf` for Hugging Face models, `vllm` for a vllm runner, `api` for API endpoints, `llama_cpp` for llama cpp, `mlx` for mlx, `bedrock` for the AWS Bedrock API, and `together` for the Together.ai API |
| -m, --model | Model that will be tested and used to generate the queries. Some options for OpenAI models are the chat models `gpt-3.5-turbo-0613` and `gpt-4-0613`. Options for Anthropic include the latest claude-3 family of models (e.g. `claude-3-opus-20240229`). For Hugging Face and vLLM models, simply use the path of your chosen model (e.g. `defog/sqlcoder`). |
| -a, --adapter | Path to the relevant adapter model you're using. Only available for the `hf_runner`. |
| --api_url | The URL of the custom API you want to send the prompt to. Only used when model_type is `api`. |
174 changes: 174 additions & 0 deletions eval/together_runner.py
@@ -0,0 +1,174 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict

from eval.eval import compare_query_results
import pandas as pd
from utils.gen_prompt import generate_prompt
from utils.questions import prepare_questions_df
from utils.creds import db_creds_all
from tqdm import tqdm
from time import time
from together import Together
from utils.reporting import upload_results


client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))


def process_row(row: Dict, model: str):
start_time = time()
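    # Llama 3 chat models mark the end of a turn/message with these special tokens,
    # so pass them as stop sequences; other model families need their own stop tokens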
if model.startswith("meta-llama"):
stop = ["<|eot_id|>", "<|eom_id|>"]
else:
print(
"Undefined stop token(s). Please specify the stop token(s) for the model."
)
stop = []
messages = row["prompt"]
response = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=800,
temperature=0.0,
stop=stop,
stream=False,
)
content = response.choices[0].message.content
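    # the prompt's final assistant message is primed with "```sql\n", so the
    # completion is the SQL itself, up to the first closing fence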
generated_query = content.split("```", 1)[0].strip()
end_time = time()

row["generated_query"] = generated_query
row["latency_seconds"] = end_time - start_time
row["tokens_used"] = None
golden_query = row["query"]
db_name = row["db_name"]
db_type = row["db_type"]
question = row["question"]
query_category = row["query_category"]
table_metadata_string = row["table_metadata_string"]
exact_match = correct = 0

try:
exact_match, correct = compare_query_results(
query_gold=golden_query,
query_gen=generated_query,
db_name=db_name,
db_type=db_type,
db_creds=db_creds_all[row["db_type"]],
question=question,
query_category=query_category,
table_metadata_string=table_metadata_string,
)
row["exact_match"] = int(exact_match)
row["correct"] = int(correct)
row["error_msg"] = ""
    except Exception as e:
        # default the correctness fields so downstream aggregation never hits a KeyError
        row["exact_match"] = 0
        row["correct"] = 0
        row["error_db_exec"] = 1
        row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"

return row


def run_together_eval(args):
# get params from args
questions_file_list = args.questions_file
prompt_file_list = args.prompt_file
num_questions = args.num_questions
public_data = not args.use_private_data
output_file_list = args.output_file
k_shot = args.k_shot
max_workers = args.parallel_threads
db_type = args.db_type
decimal_points = args.decimal_points
model = args.model
cot_table_alias = args.cot_table_alias

for questions_file, prompt_file, output_file in zip(
questions_file_list, prompt_file_list, output_file_list
):
if not prompt_file.endswith(".json"):
raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}")
print(f"Using prompt file {prompt_file}")
# get questions
print("Preparing questions...")
print(
f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
)
df = prepare_questions_df(
questions_file, db_type, num_questions, k_shot, cot_table_alias
)
# create a prompt for each question
        # note that the prompt for Together AI uses the OpenAI chat message format
df["prompt"] = df.apply(
lambda row: generate_prompt(
prompt_file,
row["question"],
row["db_name"],
row["db_type"],
row["instructions"],
row["k_shot_prompt"],
row["glossary"],
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
row["cot_instructions"],
row["cot_pregen"],
public_data,
args.num_columns,
args.shuffle_metadata,
row["table_aliases"],
),
axis=1,
)

total_tried = 0
total_correct = 0
output_rows = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for row in df.to_dict("records"):
futures.append(executor.submit(process_row, row, model))

with tqdm(as_completed(futures), total=len(futures)) as pbar:
for f in pbar:
row = f.result()
output_rows.append(row)
if row["correct"]:
total_correct += 1
total_tried += 1
pbar.set_description(
f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
)

output_df = pd.DataFrame(output_rows)
del output_df["prompt"]
print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
# get directory of output_file and create if not exist
output_dir = os.path.dirname(output_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
try:
output_df.to_csv(output_file, index=False, float_format="%.2f")
        except Exception:
output_df.to_pickle(output_file)

results = output_df.to_dict("records")
# upload results
with open(prompt_file, "r") as f:
prompt = f.read()
if args.upload_url is not None:
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
prompt=prompt,
args=args,
)
4 changes: 4 additions & 0 deletions main.py
@@ -131,6 +131,10 @@
from eval.bedrock_runner import run_bedrock_eval

run_bedrock_eval(args)
elif args.model_type == "together":
from eval.together_runner import run_together_eval

run_together_eval(args)
else:
raise ValueError(
f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'hf', 'anthropic', 'vllm', 'api', 'llama_cpp', 'mlx', 'gemini', 'mistral'"
14 changes: 14 additions & 0 deletions prompts/prompt_together.json
@@ -0,0 +1,14 @@
[
{
"role": "system",
"content": "Your role is to convert a user question to a {db_type} query, given a database schema."
},
{
"role": "user",
"content": "Generate a SQL query that answers the question `{user_question}`.\n{instructions}\nThis query will run on a database whose schema is represented in this SQL DDL:\n{table_metadata_string}\n{table_aliases}\n{pruned_join_str}\n{k_shot_prompt}\nReturn the SQL query that answers the question `{user_question}`"
},
{
"role": "assistant",
"content": "```sql\n"
}
]
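
Note the trailing assistant message: it primes the model to continue inside an already-opened SQL code fence, which is why `together_runner.py` keeps only the text before the first closing fence. A small illustration with a hypothetical completion:

```python
# hypothetical completion continuing the primed "```sql\n" assistant turn
content = "SELECT COUNT(*) FROM orders;\n```\n\nThis query counts all orders."
generated_query = content.split("```", 1)[0].strip()
print(generated_query)  # SELECT COUNT(*) FROM orders;
```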
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,6 +16,7 @@ snowflake-connector-python
spacy
sqlalchemy
tiktoken
together
torch
tqdm
transformers
89 changes: 64 additions & 25 deletions utils/gen_prompt.py
@@ -1,3 +1,5 @@
from copy import deepcopy
import json
from typing import Dict, List, Optional
import numpy as np
from utils.dialects import (
@@ -120,11 +122,23 @@ def generate_prompt(
public_data=True,
columns_to_keep=40,
shuffle_metadata=False,
table_aliases="",
):
"""
Generates the prompt for the given question.
If a json file is passed in as the prompt_file, please ensure that it is a list
of dictionaries, which should have the `content` key minimally.
Else, we will treat the file as a string template.
"""
from defog_data.metadata import dbs # to avoid CI error

with open(prompt_file, "r") as f:
prompt = f.read()
is_json = prompt_file.endswith(".json")
if is_json:
with open(prompt_file, "r") as f:
messages_template = json.load(f)
else:
with open(prompt_file, "r") as f:
prompt = f.read()
question_instructions = question + " " + instructions
table_names = []

@@ -232,26 +246,51 @@
)
instruction_reflections = instruction_reflections + "\n"

prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=table_metadata_string,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
question_0=question_0,
query_0=query_0,
question_1=question_1,
query_1=query_1,
cot_instructions=cot_instructions,
instruction_reflections=instruction_reflections,
join_hints=join_str,
pruned_join_hints=pruned_join_str,
)

if cot_pregen:
table_aliases = generate_aliases(table_names)
prompt = prompt + table_aliases
return prompt
if is_json:
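        # chat-style prompt: format each message's content field, leaving roles untouched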
messages = []
for msg_template in messages_template:
msg = deepcopy(msg_template)
msg["content"] = msg_template["content"].format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=table_metadata_string,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
question_0=question_0,
query_0=query_0,
question_1=question_1,
query_1=query_1,
cot_instructions=cot_instructions,
instruction_reflections=instruction_reflections,
table_aliases=table_aliases,
join_str=join_str,
pruned_join_str=pruned_join_str,
)
messages.append(msg)
return messages
else:
prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=table_metadata_string,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
question_0=question_0,
query_0=query_0,
question_1=question_1,
query_1=query_1,
cot_instructions=cot_instructions,
instruction_reflections=instruction_reflections,
join_hints=join_str,
pruned_join_hints=pruned_join_str,
)
if cot_pregen:
table_aliases = generate_aliases(table_names)
prompt = prompt + table_aliases
return prompt
4 changes: 4 additions & 0 deletions utils/questions.py
@@ -136,8 +136,12 @@ def prepare_questions_df(
question_query_df["cot_instructions"] = question_query_df["db_name"].apply(
get_table_aliases
)
question_query_df["table_aliases"] = question_query_df["db_name"].apply(
get_table_aliases
)
else:
question_query_df["cot_instructions"] = ""
question_query_df["table_aliases"] = ""

if cot_table_alias == "pregen":
question_query_df["cot_pregen"] = True