Skip to content

Commit

Permalink
Add additional re-ranking step using question length post semantic ra…
Browse files Browse the repository at this point in the history
…nking/selection
  • Loading branch information
pramitchoudhary authored Oct 26, 2023
2 parents c406b97 + 8fd009a commit d660581
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
36 changes: 15 additions & 21 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,17 @@
import torch
import torch.nn.functional as F
from langchain import OpenAI
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
from sidekick.configs.prompt_template import (
DEBUGGING_PROMPT,
NSQL_QUERY_PROMPT,
QUERY_PROMPT,
STARCODER2_PROMPT,
TASK_PROMPT,
)
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
NSQL_QUERY_PROMPT, QUERY_PROMPT,
STARCODER2_PROMPT, TASK_PROMPT)
from sidekick.logger import logger
from sidekick.utils import (
_check_file_info,
filter_samples,
is_resource_low,
load_causal_lm_model,
load_embedding_model,
make_dir,
read_sample_pairs,
remove_duplicates,
)
from sidekick.utils import (_check_file_info, is_resource_low,
load_causal_lm_model, load_embedding_model,
make_dir, re_rank, read_sample_pairs,
remove_duplicates, semantic_search)
from sqlalchemy import create_engine


Expand Down Expand Up @@ -259,7 +250,7 @@ def generate_tasks(self, table_names: list, input_question: str):

# Filter closest samples to the input question, threshold = 0.45
filtered_context = (
filter_samples(
semantic_search(
input_question,
updated_context,
m_path,
Expand Down Expand Up @@ -375,7 +366,7 @@ def generate_sql(
}

m_path = f"{self.path}/models/sentence_transformers/"
filtered_context = filter_samples(
filtered_context = semantic_search(
model_obj=self.similarity_model,
input_q=input_question,
probable_qs=list(_context.keys()),
Expand All @@ -401,7 +392,7 @@ def generate_sql(

# Filter closest samples to the input question, threshold = 0.9
filtered_context = (
filter_samples(
semantic_search(
input_q=input_question,
probable_qs=context_queries,
model_path=m_path,
Expand All @@ -414,9 +405,12 @@ def generate_sql(
)
logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
# If QnA pairs > 5, we keep top 5 for focused context
# Most relevant match is closest to the generation post re-ranking
_samples = filtered_context
_samples = re_rank(input_question, _samples)
if len(filtered_context) > 5:
_samples = filtered_context[0:5][::-1]
_samples = re_rank(input_question, _samples)

qna_samples = "\n".join(_samples)

Expand Down
15 changes: 14 additions & 1 deletion sidekick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"sqld": "SQL Debugging",
}


def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
# Reference:
# 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
Expand Down Expand Up @@ -85,7 +86,19 @@ def generate_text_embeddings(model_path: str, x, model_obj=None, batch_size: int
return res


def filter_samples(
def re_rank(question: str, input_x: list):
# Currently using question length as final step to re-rank, might change in future
input_pqs = [_se.strip().lower().split("answer:")[0].strip() for _se in input_x[0:5]]
_dist = np.array([len(_in.split()) for _in in input_pqs])

query_len = len(question.lower().split())
logger.debug(f"Question length: {query_len}")
sorted_ = np.argsort(abs(_dist - query_len))[::-1].tolist()
res = list(np.array(input_x)[sorted_])
return res


def semantic_search(
input_q: str,
probable_qs: list,
model_path: str,
Expand Down

0 comments on commit d660581

Please sign in to comment.