From e60f87be9d5dd64d3e7eb63d6043b342794e3200 Mon Sep 17 00:00:00 2001
From: Rishabh Srivastava
Date: Fri, 11 Aug 2023 22:19:26 +0000
Subject: [PATCH] formatting changes

---
 eval/hf_runner.py | 83 ++++++++++++++++++++++++++++-------------------
 main.py           | 14 ++++----
 utils/pruning.py  | 40 ++++++++++++++++-------
 3 files changed, 87 insertions(+), 50 deletions(-)

diff --git a/eval/hf_runner.py b/eval/hf_runner.py
index 4817dcf..a06dfad 100644
--- a/eval/hf_runner.py
+++ b/eval/hf_runner.py
@@ -7,6 +7,7 @@
 from psycopg2.extensions import QueryCanceledError
 from time import time
 
+
 def prepare_questions_df(questions_file, num_questions):
     question_query_df = pd.read_csv(questions_file, nrows=num_questions)
     question_query_df["generated_query"] = ""
@@ -24,44 +25,55 @@ def prepare_questions_df(questions_file, num_questions):
     question_query_df.reset_index(inplace=True, drop=True)
     return question_query_df
 
+
 def generate_prompt(prompt_file, question, db_name):
     with open(prompt_file, "r") as f:
         prompt = f.read()
-    
+
     pruned_metadata_str = prune_metadata_str(question, db_name)
-    prompt = prompt.format(user_question = question, table_metadata_string=pruned_metadata_str)
+    prompt = prompt.format(
+        user_question=question, table_metadata_string=pruned_metadata_str
+    )
     return prompt
 
+
 def get_tokenizer_model(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto", use_cache=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        use_cache=True,
+    )
     return tokenizer, model
 
+
 def run_hf_eval(
-    questions_file : str,
+    questions_file: str,
     prompt_file: str,
     num_questions: int = None,
-    model_name : str = "defog/starcoder-finetune-v3",
-    output_file : str = "results.csv",
+    model_name: str = "defog/starcoder-finetune-v3",
+    output_file: str = "results.csv",
 ):
     print("preparing questions...")
     # get questions
     df = prepare_questions_df(questions_file, num_questions)
 
     # create a prompt for each question
-    df["prompt"] = df[['question', 'db_name']].apply(lambda row: generate_prompt(prompt_file, row['question'], row['db_name']), axis=1)
+    df["prompt"] = df[["question", "db_name"]].apply(
+        lambda row: generate_prompt(prompt_file, row["question"], row["db_name"]),
+        axis=1,
+    )
 
     print("questions prepared\nnow loading model...")
     # initialize tokenizer and model
     tokenizer, model = get_tokenizer_model(model_name)
-    
+
     print("model loaded\nnow generating and evaluating predictions...")
     # generate predictions
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
-    pipe = pipeline("text-generation",
-        model=model,
-        tokenizer=tokenizer
-    )
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
     # from here, just do the usual eval stuff
     total_tried = 0
@@ -71,17 +83,22 @@ def run_hf_eval(
     for row in df.to_dict("records"):
         total_tried += 1
         start_time = time()
-        generated_query = pipe(
-            row['prompt'],
-            max_new_tokens=600,
-            do_sample=False,
-            num_beams=4,
-            num_return_sequences=1,
-            eos_token_id=eos_token_id,
-            pad_token_id=eos_token_id,
-        )[0]['generated_text'].split("```sql")[-1].split(";")[0].strip()
+        generated_query = (
+            pipe(
+                row["prompt"],
+                max_new_tokens=600,
+                do_sample=False,
+                num_beams=4,
+                num_return_sequences=1,
+                eos_token_id=eos_token_id,
+                pad_token_id=eos_token_id,
+            )[0]["generated_text"]
+            .split("```sql")[-1]
+            .split(";")[0]
+            .strip()
+        )
         end_time = time()
-        
+
         row["generated_query"] = generated_query
         row["latency_seconds"] = end_time - start_time
         golden_query = row["query"]
@@ -92,14 +109,14 @@ def run_hf_eval(
 
         generated_result = expected_result = None
         try:
-            expected_result = query_postgres_db(
-                golden_query, db_name
-            ).rename(columns=str.lower)
-
-            generated_result = query_postgres_db(
-                generated_query, db_name
-            ).rename(columns=str.lower)
-
+            expected_result = query_postgres_db(golden_query, db_name).rename(
+                columns=str.lower
+            )
+
+            generated_result = query_postgres_db(generated_query, db_name).rename(
+                columns=str.lower
+            )
+
             correct = subset = int(
                 compare_df(
                     expected_result, generated_result, query_category, question
@@ -123,13 +140,13 @@ def run_hf_eval(
         except Exception as e:
             row["error_db_exec"] = 1
             row["error_msg"] = f"QUERY EXECUTION ERROR: {e}"
-        
+
         output_rows.append(row)
         pbar.update(1)
         pbar.set_description(
             f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
         )
-    
+
     output_df = pd.DataFrame(output_rows)
     output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
-    output_df.to_csv(output_file, index=False, float_format="%.2f")
\ No newline at end of file
+    output_df.to_csv(output_file, index=False, float_format="%.2f")
diff --git a/main.py b/main.py
index e8b0f36..475b356 100644
--- a/main.py
+++ b/main.py
@@ -22,11 +22,13 @@
         run_openai_eval(args)
     elif args.model_type == "hf":
         run_hf_eval(
-            questions_file = args.questions_file,
-            prompt_file = args.prompt_file,
-            num_questions = args.num_questions,
-            model_name = args.model,
-            output_file = args.output_file,
+            questions_file=args.questions_file,
+            prompt_file=args.prompt_file,
+            num_questions=args.num_questions,
+            model_name=args.model,
+            output_file=args.output_file,
         )
     else:
-        raise ValueError(f"Invalid model type: {args.model_type}. Model type must be one of: 'openai', 'hf'")
+        raise ValueError(
+            f"Invalid model type: {args.model_type}. Model type must be one of: 'openai', 'hf'"
+        )
diff --git a/utils/pruning.py b/utils/pruning.py
index a954149..af0a4c0 100644
--- a/utils/pruning.py
+++ b/utils/pruning.py
@@ -8,6 +8,7 @@
 encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 nlp = spacy.load("en_core_web_sm")
 
+
 def load_all_emb() -> Tuple[Dict[str, torch.tensor], List[str]]:
     """
     Load all embeddings from pickle file.
@@ -20,6 +21,7 @@ def load_all_emb() -> Tuple[Dict[str, torch.tensor], List[str]]:
         print("Embeddings not found.")
         exit(1)
 
+
 def load_ner_md() -> Tuple[Dict[str, Dict], Dict[str, Dict], Dict[str, Dict]]:
     """
     Load all NER and join metadata from pickle file.
@@ -32,6 +34,7 @@ def load_ner_md() -> Tuple[Dict[str, Dict], Dict[str, Dict], Dict[str, Dict]]:
         print("NER and join metadata not found.")
         exit(1)
 
+
 def knn(
     query: str,
     all_emb: torch.tensor,
@@ -49,13 +52,16 @@ def knn(
         return torch.tensor([]), torch.tensor([])
     # if only 1 result is returned, we need to convert it to a tensor
     elif top_results.numel() == 1:
-        return torch.tensor([similarity_scores[top_results]]), torch.tensor([top_results])
+        return torch.tensor([similarity_scores[top_results]]), torch.tensor(
+            [top_results]
+        )
     else:
         top_k_scores, top_k_indices = torch.topk(
             similarity_scores[top_results], k=min(k, top_results.numel())
         )
         return top_k_scores, top_results[top_k_indices]
 
+
 def get_entity_types(sentence, verbose: bool = False):
     """
     Get entity types from sentence using spaCy.
@@ -69,7 +75,11 @@ def get_entity_types(sentence, verbose: bool = False):
     return named_entities
 
-def format_topk_sql(topk_table_columns: Dict[str, List[Tuple[str, str, str]]], exclude_column_descriptions: bool = False) -> str:
+
+def format_topk_sql(
+    topk_table_columns: Dict[str, List[Tuple[str, str, str]]],
+    exclude_column_descriptions: bool = False,
+) -> str:
     md_str = "```\n"
     for table_name in topk_table_columns:
         columns_str = ""
@@ -77,13 +87,16 @@ def format_topk_sql(topk_table_columns: Dict[str, List[Tuple[str, str, str]]], e
             if exclude_column_descriptions:
                 columns_str += f"\n {column_tuple[0]} {column_tuple[1]},"
             else:
-                columns_str += f"\n {column_tuple[0]} {column_tuple[1]}, --{column_tuple[2]}"
+                columns_str += (
+                    f"\n {column_tuple[0]} {column_tuple[1]}, --{column_tuple[2]}"
+                )
     md_str += f"CREATE TABLE {table_name} ({columns_str}\n)\n-----------\n"
     return md_str
 
+
 def get_md_emb(
     question: str,
-    column_emb: torch.tensor, 
+    column_emb: torch.tensor,
     column_info_csv: List[str],
     column_ner: Dict[str, List[str]],
     column_join: Dict[str, dict],
@@ -114,13 +127,15 @@ def get_md_emb(
             topk_table_columns[table_name] = []
         topk_table_columns[table_name].append(column_tuple)
         table_column_names.add(f"{table_name}.{column_tuple[0]}")
-    
+
     # 2) get entity types from question + add corresponding columns
     entity_types = get_entity_types(question)
     for entity_type in entity_types:
         if entity_type in column_ner:
             for column_info in column_ner[entity_type]:
-                table_column_name, column_type, column_description = column_info.split(",", 2)
+                table_column_name, column_type, column_description = column_info.split(
+                    ",", 2
+                )
                 table_name, column_name = table_column_name.split(".", 1)
                 if table_name not in topk_table_columns:
                     topk_table_columns[table_name] = []
@@ -128,13 +143,15 @@ def get_md_emb(
                 if column_tuple not in topk_table_columns[table_name]:
                     topk_table_columns[table_name].append(column_tuple)
     topk_tables = sorted(list(topk_table_columns.keys()))
-    
+
     # 3) get table pairs that can be joined
     # create dict of table_column_name -> column_tuple for lookups
     column_name_to_tuple = {}
     ncols = len(column_info_csv)
     for i in range(ncols):
-        table_column_name, column_type, column_description = column_info_csv[i].split(",", 2)
+        table_column_name, column_type, column_description = column_info_csv[i].split(
+            ",", 2
+        )
         table_name, column_name = table_column_name.split(".", 1)
         column_tuple = (column_name, column_type, column_description)
         column_name_to_tuple[table_column_name] = column_tuple
@@ -159,16 +176,17 @@ def get_md_emb(
             join_str = f"{table_col_1} can be joined with {table_col_2}"
             if join_str not in join_list:
                 join_list.append(join_str)
-    
+
     # 4) format metadata string
     md_str = format_topk_sql(topk_table_columns, exclude_column_descriptions)
-    
+
     if join_list:
         md_str += "```\n\nAdditionally, is a list of joinable columns in this database schema:\n```\n"
         md_str += "\n".join(join_list)
         md_str += "\n```"
     return md_str
 
+
 def prune_metadata_str(question, db_name, exclude_column_descriptions=False):
     emb, csv_descriptions = load_all_emb()
     columns_ner, columns_join = load_ner_md()
@@ -180,4 +198,4 @@ def prune_metadata_str(question, db_name, exclude_column_descriptions=False):
         columns_join[db_name],
         exclude_column_descriptions=exclude_column_descriptions,
     )
-    return table_metadata_csv
\ No newline at end of file
+    return table_metadata_csv
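
Note on the reformatted generation block in eval/hf_runner.py: the prompt asks the model to answer inside a ```sql fence, generation stops at the "```" token (passed as both eos_token_id and pad_token_id), and the SQL is then sliced out of the raw completion. A minimal sketch of that post-processing step, with extract_sql as a hypothetical helper name (the patch keeps this logic inlined in run_hf_eval):

    def extract_sql(generated_text: str) -> str:
        # keep everything after the last ```sql fence, then everything
        # before the first semicolon, mirroring the chained
        # .split("```sql")[-1].split(";")[0].strip() in run_hf_eval
        return generated_text.split("```sql")[-1].split(";")[0].strip()

    assert extract_sql("prompt...```sql\nSELECT 1;\n```") == "SELECT 1"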
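
Note on the evaluation loop in eval/hf_runner.py: both the golden and the generated query are executed with query_postgres_db, the column names of both result frames are lower-cased via .rename(columns=str.lower), and compare_df decides correctness. compare_df itself is outside this diff; a hedged sketch of one way such a comparison can work (frames_match is illustrative, not the repository's implementation):

    import pandas as pd

    def frames_match(expected: pd.DataFrame, generated: pd.DataFrame) -> bool:
        # normalize column case, then compare values ignoring row order
        expected = expected.rename(columns=str.lower)
        generated = generated.rename(columns=str.lower)
        if sorted(expected.columns) != sorted(generated.columns):
            return False
        cols = sorted(expected.columns)
        e = expected[cols].sort_values(by=cols).reset_index(drop=True)
        g = generated[cols].sort_values(by=cols).reset_index(drop=True)
        return e.equals(g)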
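
Note on utils/pruning.py: prune_metadata_str trims the schema shown to the model by (1) ranking column embeddings against the question via knn, (2) adding columns whose NER categories match entity types spaCy extracts from the question, (3) collecting joinable column pairs, and (4) rendering the surviving columns as CREATE TABLE statements with format_topk_sql. A hedged sketch of just the similarity-ranking step, using the same encoder the module loads; rank_columns and its threshold value are illustrative assumptions, and the real knn additionally special-cases zero or one match:

    import torch
    from sentence_transformers import SentenceTransformer, util

    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def rank_columns(question: str, column_emb: torch.Tensor, k: int = 20,
                     threshold: float = 0.2):
        # embed the question and score it against every column embedding
        question_emb = encoder.encode(question, convert_to_tensor=True)
        similarity_scores = util.cos_sim(question_emb, column_emb).squeeze(0)
        # keep indices scoring above the threshold, then take the top k
        top_results = torch.nonzero(similarity_scores > threshold).squeeze(-1)
        top_k_scores, top_k_indices = torch.topk(
            similarity_scores[top_results], k=min(k, top_results.numel())
        )
        return top_k_scores, top_results[top_k_indices]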