created hf runner to improve runtime efficiency
rishsriv committed Aug 11, 2023
1 parent a24920d commit e637991
Showing 4 changed files with 99 additions and 6 deletions.
10 changes: 9 additions & 1 deletion eval/eval.py
@@ -47,7 +47,7 @@ def escape_percent(match):


 def query_postgres_db(
-    query: str, db_name: str, db_creds: dict, timeout: float
+    query: str, db_name: str, db_creds: dict = None, timeout: float = 10.0
 ) -> pd.DataFrame:
     """
     Runs query on postgres db and returns results as a dataframe.
@@ -56,6 +56,14 @@ def query_postgres_db(
     timeout: time in seconds to wait for query to finish before timing out
     """
+    if db_creds is None:
+        db_creds = {
+            "host": "localhost",
+            "port": 5432,
+            "user": "postgres",
+            "password": "postgres",
+            "database": db_name,
+        }
     try:
         db_url = f"postgresql://{db_creds['user']}:{db_creds['password']}@{db_creds['host']}:{db_creds['port']}/{db_name}"
         engine = create_engine(db_url)
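With these defaults in place, a local evaluation run can call the helper without wiring up credentials. A minimal sketch, assuming a local Postgres instance; the query and database name here are illustrative:

```python
from eval.eval import query_postgres_db

# db_creds falls back to the localhost postgres/postgres settings added above,
# and the query times out after the new 10-second default
results_df = query_postgres_db("SELECT 1 AS ok;", db_name="postgres")
print(results_df)
```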
71 changes: 71 additions & 0 deletions eval/hf_runner.py
@@ -0,0 +1,71 @@
from eval.eval import compare_df, query_postgres_db, subset_df
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.pruning import prune_metadata_str

def prepare_questions_df(questions_file, num_questions):
question_query_df = pd.read_csv(questions_file, nrows=num_questions)
question_query_df["generated_query"] = ""
question_query_df["reason"] = ""
question_query_df["error_msg"] = ""
question_query_df["correct"] = 0
question_query_df["subset"] = 0
question_query_df["error_query_gen"] = 0
question_query_df["error_db_exec"] = 0
question_query_df["timeout"] = 0
# add custom metrics below:
question_query_df["latency_seconds"] = 0.0 # latency of query generation in seconds
question_query_df["tokens_used"] = 0 # number of tokens used in query generation

question_query_df.reset_index(inplace=True, drop=True)
return question_query_df

def generate_prompt(prompt_file, question, db_name):
with open(prompt_file, "r") as f:
prompt = f.read()

pruned_metadata_str = prune_metadata_str(question, db_name)
    prompt = prompt.format(user_question=question, table_metadata_string=pruned_metadata_str)
return prompt

def get_tokenizer_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # batched padding needs a pad token; fall back to eos if the tokenizer lacks one
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto", use_cache=True
    )
    return tokenizer, model

def run_hf_eval(
    questions_file: str,
    prompt_file: str,
    num_questions: int = None,
    model_name: str = "defog/starcoder-finetune-v3",
    max_tokens: int = 600,
):
# get questions
df = prepare_questions_df(questions_file, num_questions)

# create a prompt for each question
df["prompt"] = df[['question', 'db_name']].apply(lambda row: generate_prompt(prompt_file, row['question'], row['db_name']), axis=1)

# initialize tokenizer and model
tokenizer, model = get_tokenizer_model(model_name)

# generate predictions
    # stop generation once the model emits the token that closes the SQL fence opened in the prompt
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    inputs = tokenizer(df["prompt"].tolist(), return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=False,
num_beams=4,
num_return_sequences=1,
eos_token_id=eos_token_id,
)
    predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    df["prediction"] = predictions
    # keep only the SQL between the ```sql fence and the first semicolon
    df["generated_query"] = df["prediction"].apply(
        lambda x: x.split("```sql")[-1].split(";")[0].strip()
    )

    # from here, run the usual evaluation flow (compare_df / subset_df on the query results)

    # export raw predictions to CSV before any further processing
    df.to_csv("hf_pred.csv", index=False)
7 changes: 2 additions & 5 deletions eval/openai_runner.py
@@ -7,12 +7,9 @@
from tqdm import tqdm


-def run(args):
+def run_openai_eval(args):
     question_query_df = pd.read_csv(args.questions_file, nrows=args.num_questions)
-    if args.qg_class == "oa_chat":
-        qg_class = OpenAIChatQueryGenerator
-    else:
-        raise ValueError(f"Unknown qg_class {args.qg_class}")
+    qg_class = OpenAIChatQueryGenerator
     # add columns for generated query and metrics
     question_query_df["generated_query"] = ""
     question_query_df["reason"] = ""
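Renaming run to run_openai_eval frees the generic name for a dispatching entry point. A hypothetical sketch of such glue code; the entry module and its flags are assumptions, not part of this commit:

```python
# hypothetical entry point; flag names are assumptions
import argparse

from eval.hf_runner import run_hf_eval
from eval.openai_runner import run_openai_eval

parser = argparse.ArgumentParser()
parser.add_argument("--runner", choices=["openai", "hf"], default="openai")
parser.add_argument("--questions_file", required=True)
parser.add_argument("--prompt_file")
parser.add_argument("--num_questions", type=int)
args = parser.parse_args()

if args.runner == "hf":
    run_hf_eval(args.questions_file, args.prompt_file, args.num_questions)
else:
    run_openai_eval(args)  # the OpenAI runner takes the parsed args object directly
```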
17 changes: 17 additions & 0 deletions query_generators/prompts/sample_hf_prompt.md
@@ -0,0 +1,17 @@
### Instructions:
Your task is to generate a correct SQL query for answering a specific question about a SQL database. While doing this, you must adhere to the following guidelines:

- Use the exact tables and columns from the database. Only reference tables and columns that are explicitly listed in the provided database schema.
- Use table aliases to prevent ambiguity. For example, `SELECT table1.column_name1, table2.column_name1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- Use the "ILIKE" operator with '%' wildcards when querying text columns, e.g., `column_name ILIKE '%search_text%'`.
- When creating a ratio, always cast the numerator as float, e.g., `CAST(numerator AS float) / denominator`.

### Input:
Create a SQL query that answers the question `{user_question}`.

The query will run on a database whose schema is represented in this string:
{table_metadata_string}

### Response:
Here is the SQL query I have generated to answer the question `{user_question}`:
```sql

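generate_prompt in eval/hf_runner.py fills this template with str.format; a short sketch with made-up values shows why the template ends with an open fence:

```python
# illustrative values; the real schema string comes from prune_metadata_str
template = open("query_generators/prompts/sample_hf_prompt.md").read()
prompt = template.format(
    user_question="How many users signed up last week?",
    table_metadata_string="CREATE TABLE users (id int, created_at timestamp);",
)
# the prompt ends mid-fence, so the model's continuation is the SQL itself,
# and generation halts at the closing fence (the eos token set in hf_runner.py)
print(prompt)
```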