-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
created hf runner to improve runtime efficiency
- Loading branch information
Showing
4 changed files
with
99 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from eval.eval import compare_df, query_postgres_db, subset_df | ||
import pandas as pd | ||
import torch | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
from utils.pruning import prune_metadata_str | ||
|
||
def prepare_questions_df(questions_file, num_questions):
    """Load the first `num_questions` rows of the questions CSV and attach
    empty bookkeeping columns that the evaluation loop fills in later.

    Args:
        questions_file: path to a CSV with at least `question` and `db_name` columns.
        num_questions: number of rows to read (None reads the whole file).

    Returns:
        A pandas DataFrame with a fresh 0..n-1 index and initialized result columns.
    """
    df = pd.read_csv(questions_file, nrows=num_questions)

    # Text columns populated during evaluation.
    for text_col in ("generated_query", "reason", "error_msg"):
        df[text_col] = ""

    # Integer flag/counter columns, all starting at 0.
    for flag_col in ("correct", "subset", "error_query_gen", "error_db_exec", "timeout"):
        df[flag_col] = 0

    # add custom metrics below:
    df["latency_seconds"] = 0.0  # latency of query generation in seconds
    df["tokens_used"] = 0  # number of tokens used in query generation

    df.reset_index(inplace=True, drop=True)
    return df
|
||
def generate_prompt(prompt_file, question, db_name):
    """Build the model prompt for one question.

    Reads the prompt template from `prompt_file` and fills in the user's
    question plus a pruned metadata string for the target database.
    """
    with open(prompt_file, "r") as fp:
        template = fp.read()

    # Prune the schema string down to what is relevant for this question.
    metadata = prune_metadata_str(question, db_name)
    return template.format(user_question=question, table_metadata_string=metadata)
|
||
def get_tokenizer_model(model_name):
    """Download and return the (tokenizer, model) pair for `model_name`."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,       # allow repos that ship custom model code
        torch_dtype=torch.bfloat16,   # half-precision weights to cut memory use
        device_map="auto",            # let HF place layers across available devices
        use_cache=True,               # KV cache for faster autoregressive decoding
    )
    return tokenizer, model
|
||
def run_hf_eval(
    questions_file: str,
    prompt_file: str,
    num_questions: int = None,
    model_name: str = "defog/starcoder-finetune-v3",
    max_tokens: int = 600,
):
    """Generate SQL queries for the eval questions with a Hugging Face model
    and dump the raw predictions to `hf_pred.csv`.

    Args:
        questions_file: CSV of questions with `question` and `db_name` columns.
        prompt_file: text file containing the prompt template.
        num_questions: number of questions to evaluate (None = all).
        model_name: HF hub id of the causal-LM to load.
        max_tokens: maximum number of new tokens to generate per question.
    """
    # get questions
    df = prepare_questions_df(questions_file, num_questions)

    # create a prompt for each question
    df["prompt"] = df[["question", "db_name"]].apply(
        lambda row: generate_prompt(prompt_file, row["question"], row["db_name"]),
        axis=1,
    )

    # initialize tokenizer and model
    tokenizer, model = get_tokenizer_model(model_name)

    # Batched tokenization with padding=True requires a pad token; causal-LM
    # tokenizers often ship without one, so fall back to the EOS token to
    # avoid a "tokenizer does not have a padding token" error.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Stop generation at the closing ``` fence of the SQL code block.
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]

    inputs = tokenizer(df["prompt"].tolist(), return_tensors="pt", padding=True)
    # device_map="auto" may have placed the model on GPU while the tokenizer
    # outputs live on CPU; move the input tensors to the model's device so
    # generate() does not fail with a device mismatch.
    inputs = inputs.to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,  # greedy/beam decoding for reproducible evals
        num_beams=4,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
    )
    df["prediction"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Keep only the SQL between the ```sql fence and the first semicolon.
    df["generated_query"] = df["prediction"].apply(
        lambda x: x.split("```sql")[-1].split(";")[0].strip()
    )

    # TODO: latency_seconds / tokens_used columns are initialized upstream but
    # never populated here — fill them in once per-row timing is wired up.

    # from here, just do the usual eval stuff

    # export results to CSV before doing anything else
    df.to_csv("hf_pred.csv", index=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
### Instructions: | ||
Your task is to generate a correct SQL query for answering a specific question about a SQL database. While doing this, you must adhere to the following guidelines: | ||
|
||
- Use the exact tables and columns from the database. Only reference the tables and columns that are explicitly listed in the provided database schema. |  ||
- Use Table Aliases to prevent any ambiguity. For example, `SELECT table1.column_name1, table2.column_name1 FROM table1 JOIN table2 ON table1.id = table2.id` | ||
- Use the "ILIKE" operator with '%' wildcards when querying text columns, e.g., `column_name ILIKE '%search_text%'`. | ||
- When creating a ratio, always cast the numerator as float, e.g., `CAST(numerator AS float) / denominator`. | ||
|
||
### Input: | ||
Create a SQL query that answers the question `{user_question}`. | ||
|
||
The query will run on a database whose schema is represented in this string: | ||
{table_metadata_string} | ||
|
||
### Response: | ||
Here is the SQL query I have generated to answer the question `{user_question}`: | ||
```sql |