diff --git a/eval/eval.py b/eval/eval.py
index b9cfc45..bdd1470 100644
--- a/eval/eval.py
+++ b/eval/eval.py
@@ -47,7 +47,7 @@ def escape_percent(match):
 
 
 def query_postgres_db(
-    query: str, db_name: str, db_creds: dict, timeout: float
+    query: str, db_name: str, db_creds: dict = None, timeout: float = 10.0
 ) -> pd.DataFrame:
     """
     Runs query on postgres db and returns results as a dataframe.
@@ -56,6 +56,14 @@ def query_postgres_db(
 
     timeout: time in seconds to wait for query to finish before timing out
     """
+    if db_creds is None:
+        db_creds = {
+            "host": "localhost",
+            "port": 5432,
+            "user": "postgres",
+            "password": "postgres",
+            "database": db_name,
+        }
     try:
         db_url = f"postgresql://{db_creds['user']}:{db_creds['password']}@{db_creds['host']}:{db_creds['port']}/{db_name}"
         engine = create_engine(db_url)
diff --git a/eval/hf_runner.py b/eval/hf_runner.py
new file mode 100644
index 0000000..509f82b
--- /dev/null
+++ b/eval/hf_runner.py
@@ -0,0 +1,71 @@
+from eval.eval import compare_df, query_postgres_db, subset_df
+import pandas as pd
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from utils.pruning import prune_metadata_str
+
+def prepare_questions_df(questions_file, num_questions):
+    question_query_df = pd.read_csv(questions_file, nrows=num_questions)
+    question_query_df["generated_query"] = ""
+    question_query_df["reason"] = ""
+    question_query_df["error_msg"] = ""
+    question_query_df["correct"] = 0
+    question_query_df["subset"] = 0
+    question_query_df["error_query_gen"] = 0
+    question_query_df["error_db_exec"] = 0
+    question_query_df["timeout"] = 0
+    # add custom metrics below:
+    question_query_df["latency_seconds"] = 0.0  # latency of query generation in seconds
+    question_query_df["tokens_used"] = 0  # number of tokens used in query generation
+
+    question_query_df.reset_index(inplace=True, drop=True)
+    return question_query_df
+
+def generate_prompt(prompt_file, question, db_name):
+    with open(prompt_file, "r") as f:
+        prompt = f.read()
+
+    pruned_metadata_str = prune_metadata_str(question, db_name)
+    prompt = prompt.format(user_question=question, table_metadata_string=pruned_metadata_str)
+    return prompt
+
+def get_tokenizer_model(model_name):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto", use_cache=True)
+    return tokenizer, model
+
+def run_hf_eval(
+    questions_file: str,
+    prompt_file: str,
+    num_questions: int = None,
+    model_name: str = "defog/starcoder-finetune-v3",
+    max_tokens: int = 600,
+):
+    # get questions
+    df = prepare_questions_df(questions_file, num_questions)
+
+    # create a prompt for each question
+    df["prompt"] = df[['question', 'db_name']].apply(lambda row: generate_prompt(prompt_file, row['question'], row['db_name']), axis=1)
+
+    # initialize tokenizer and model
+    tokenizer, model = get_tokenizer_model(model_name)
+
+    # generate predictions
+    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
+    inputs = tokenizer(df["prompt"].tolist(), return_tensors="pt", padding=True)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_tokens,
+        do_sample=False,
+        num_beams=4,
+        num_return_sequences=1,
+        eos_token_id=eos_token_id,
+    )
+    predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    df["prediction"] = predictions
+    df['generated_query'] = df['prediction'].apply(lambda x: x.split("```sql")[-1].split(";")[0].strip())
+
+    # from here, just do the usual eval stuff
+
+    # export results to CSV before doing anything else
+    df.to_csv("hf_pred.csv", index=False)
\ No newline at end of file
diff --git a/eval/openai_runner.py b/eval/openai_runner.py
index bfc762d..ce24220 100644
--- a/eval/openai_runner.py
+++ b/eval/openai_runner.py
@@ -7,12 +7,9 @@
 from tqdm import tqdm
 
 
-def run(args):
+def run_openai_eval(args):
     question_query_df = pd.read_csv(args.questions_file, nrows=args.num_questions)
-    if args.qg_class == "oa_chat":
-        qg_class = OpenAIChatQueryGenerator
-    else:
-        raise ValueError(f"Unknown qg_class {args.qg_class}")
+    qg_class = OpenAIChatQueryGenerator
     # add columns for generated query and metrics
     question_query_df["generated_query"] = ""
     question_query_df["reason"] = ""
diff --git a/query_generators/prompts/sample_hf_prompt.md b/query_generators/prompts/sample_hf_prompt.md
new file mode 100644
index 0000000..7c3e2bb
--- /dev/null
+++ b/query_generators/prompts/sample_hf_prompt.md
@@ -0,0 +1,17 @@
+### Instructions:
+Your task is to generate a correct SQL query for answering a specific question about a SQL database. While doing this, you must adhere to the following guidelines:
+
+- Use exact tables and columns from database. Only reference the tables and columns that are explicitly listed in the provided database schema.
+- Use Table Aliases to prevent any ambiguity. For example, `SELECT table1.column_name1, table2.column_name1 FROM table1 JOIN table2 ON table1.id = table2.id`
+- Use the "ILIKE" operator with '%' wildcards when querying text columns, e.g., `column_name ILIKE '%search_text%'`.
+- When creating a ratio, always cast the numerator as float, e.g., `CAST(numerator AS float) / denominator`.
+
+### Input:
+Create a SQL query that answers the question `{user_question}`.
+
+The query will run on a database whose schema is represented in this string:
+{table_metadata_string}
+
+### Response:
+Here is the SQL query I have generated to answer the question `{user_question}`:
+```sql
\ No newline at end of file
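
For reference, here is a minimal sketch of how the new `run_hf_eval` entry point could be invoked once this change lands. The questions-file path and question count below are illustrative assumptions and are not defined anywhere in this diff; only the prompt path, model name, and token limit mirror the defaults added above.

```python
# Hypothetical driver script for the new Hugging Face runner (not part of this diff).
# Assumes it is run from the repo root so that `eval` and `utils` are importable,
# and that the questions CSV (a placeholder path) has `question` and `db_name` columns.
from eval.hf_runner import run_hf_eval

if __name__ == "__main__":
    run_hf_eval(
        questions_file="data/questions_gen.csv",  # placeholder path, assumption
        prompt_file="query_generators/prompts/sample_hf_prompt.md",
        num_questions=10,
        model_name="defog/starcoder-finetune-v3",
        max_tokens=600,
    )
```

As in the runner above, the generated queries would be written to `hf_pred.csv` before any correctness checks are run.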