From 8fd009a50566575ed81e527fd140c6f426e6f28b Mon Sep 17 00:00:00 2001
From: pramitchoudhary
Date: Thu, 26 Oct 2023 12:14:29 -0700
Subject: [PATCH] Add additional re-ranking step using question length post
 semantic ranking #49

---
 sidekick/query.py | 36 +++++++++++++++---------------------
 sidekick/utils.py | 15 ++++++++++++++-
 2 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/sidekick/query.py b/sidekick/query.py
index 7a7afa6..4a29b70 100644
--- a/sidekick/query.py
+++ b/sidekick/query.py
@@ -12,26 +12,17 @@
 import torch
 import torch.nn.functional as F
 from langchain import OpenAI
-from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
+from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
+                         LLMPredictor, ServiceContext, SQLDatabase)
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import (
-    DEBUGGING_PROMPT,
-    NSQL_QUERY_PROMPT,
-    QUERY_PROMPT,
-    STARCODER2_PROMPT,
-    TASK_PROMPT,
-)
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
+                                              NSQL_QUERY_PROMPT, QUERY_PROMPT,
+                                              STARCODER2_PROMPT, TASK_PROMPT)
 from sidekick.logger import logger
-from sidekick.utils import (
-    _check_file_info,
-    filter_samples,
-    is_resource_low,
-    load_causal_lm_model,
-    load_embedding_model,
-    make_dir,
-    read_sample_pairs,
-    remove_duplicates,
-)
+from sidekick.utils import (_check_file_info, is_resource_low,
+                            load_causal_lm_model, load_embedding_model,
+                            make_dir, re_rank, read_sample_pairs,
+                            remove_duplicates, semantic_search)
 from sqlalchemy import create_engine
 
 
@@ -259,7 +250,7 @@ def generate_tasks(self, table_names: list, input_question: str):
 
         # Filter closest samples to the input question, threshold = 0.45
         filtered_context = (
-            filter_samples(
+            semantic_search(
                 input_question,
                 updated_context,
                 m_path,
@@ -375,7 +366,7 @@ def generate_sql(
 
             }
             m_path = f"{self.path}/models/sentence_transformers/"
-            filtered_context = filter_samples(
+            filtered_context = semantic_search(
                 model_obj=self.similarity_model,
                 input_q=input_question,
                 probable_qs=list(_context.keys()),
@@ -401,7 +392,7 @@ def generate_sql(
 
             # Filter closest samples to the input question, threshold = 0.9
             filtered_context = (
-                filter_samples(
+                semantic_search(
                     input_q=input_question,
                     probable_qs=context_queries,
                     model_path=m_path,
@@ -414,9 +405,12 @@ def generate_sql(
             )
             logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
             # If QnA pairs > 5, we keep top 5 for focused context
+            # Most relevant match is closest to the generation post re-ranking
             _samples = filtered_context
+            _samples = re_rank(input_question, _samples)
             if len(filtered_context) > 5:
                 _samples = filtered_context[0:5][::-1]
+                _samples = re_rank(input_question, _samples)
             qna_samples = "\n".join(_samples)
 
 
diff --git a/sidekick/utils.py b/sidekick/utils.py
index ce02359..c0950e6 100644
--- a/sidekick/utils.py
+++ b/sidekick/utils.py
@@ -33,6 +33,7 @@
     "sqld": "SQL Debugging",
 }
 
+
 def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
     # Reference:
     # 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
@@ -85,7 +86,19 @@ def generate_text_embeddings(model_path: str, x, model_obj=None, batch_size: int
     return res
 
 
-def filter_samples(
+def re_rank(question: str, input_x: list):
+    # Currently using question length as final step to re-rank, might change in future
+    input_pqs = [_se.strip().lower().split("answer:")[0].strip() for _se in input_x[0:5]]
+    _dist = np.array([len(_in.split()) for _in in input_pqs])
+
+    query_len = len(question.lower().split())
+    logger.debug(f"Question length: {query_len}")
+    sorted_ = np.argsort(abs(_dist - query_len))[::-1].tolist()
+    res = list(np.array(input_x)[sorted_])
+    return res
+
+
+def semantic_search(
     input_q: str,
     probable_qs: list,
     model_path: str,
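
Note on the re-ranking step (illustrative commentary, not part of the patch):
re_rank() orders the retrieved QnA samples by the absolute difference between
each sample question's word count and the input question's word count, in
descending order, so the sample closest in length lands last, nearest the
generation point once the samples are joined into the prompt. A minimal
standalone sketch of that behavior, with the logger.debug call dropped and
hypothetical QnA pairs:

    import numpy as np

    def re_rank(question: str, input_x: list):
        # Question text is the part before "Answer:"; only the first 5
        # samples are considered, matching the patch.
        input_pqs = [_se.strip().lower().split("answer:")[0].strip() for _se in input_x[0:5]]
        _dist = np.array([len(_in.split()) for _in in input_pqs])
        query_len = len(question.lower().split())
        # Sort by |word-count difference|, descending: closest-length sample last.
        sorted_ = np.argsort(abs(_dist - query_len))[::-1].tolist()
        return list(np.array(input_x)[sorted_])

    # Hypothetical QnA pairs in the "question Answer: SQL" format the code expects.
    samples = [
        "How many rows are there? Answer: SELECT COUNT(*) FROM sleep",
        "What is the average sleep duration of all users? Answer: SELECT AVG(duration) FROM sleep",
        "List users. Answer: SELECT DISTINCT user_id FROM sleep",
    ]
    print(re_rank("What is the average sleep duration?", samples))
    # Question word counts are 5, 9, and 2 against a 6-word input, so the
    # result order is: "List users." (diff 4), then the average-duration
    # question (diff 3), then "How many rows are there?" (diff 1) last.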