Skip to content

Commit

Permalink
Update hybrid search weights (#231)
Browse files: browse the repository at this point in the history.
* Update hybrid search weights
* Add SAM template outputs

---------

Co-authored-by: Brendan Quinn <[email protected]>
  • Loading branch information
mbklein and bmquinn authored Jul 22, 2024
1 parent df48b0f commit 6c84bdc
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 36 deletions.
8 changes: 4 additions & 4 deletions chat/src/handlers/opensearch_neural_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ def __init__(
self.text_field = text_field

def similarity_search(
self, query: str, k: int = 10, subquery: Any = None, **kwargs: Any
self, query: str, k: int = 10, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to the embedding vector."""
docs_with_scores = self.similarity_search_with_score(
query, k, subquery, **kwargs
query, k, **kwargs
)
return [doc[0] for doc in docs_with_scores]

def similarity_search_with_score(
self, query: str, k: int = 10, subquery: Any = None, **kwargs: Any
self, query: str, k: int = 10, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query."""
dsl = hybrid_query(query=query, model_id=self.model_id, vector_field=self.vector_field, k=k, subquery=subquery, **kwargs)
dsl = hybrid_query(query=query, model_id=self.model_id, vector_field=self.vector_field, k=k, **kwargs)
response = self.client.search(index=self.index, body=dsl, params={"search_pipeline": self.search_pipeline} if self.search_pipeline else None)
documents_with_scores = [
(
Expand Down
23 changes: 7 additions & 16 deletions chat/src/helpers/hybrid_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,18 @@ def filter(query: dict):
}
}

def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 10, subquery: Any = None, **kwargs: Any):
if subquery:
weights = [0.5, 0.3, 0.2]
else:
weights = [0.7, 0.3]

def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 10, **kwargs: Any):
result = {
"size": k,
"query": {
"hybrid": {
"queries": [
filter({
"query_string": {
"default_operator": "AND",
"fields": ["title^5", "all_controlled_labels", "all_ids^5"],
"query": query
"default_operator": "AND",
"fields": ["all_titles^5", "all_controlled_labels", "all_ids^5"],
"query": query,
"analyzer": "english"
}
}),
filter({
Expand All @@ -47,7 +43,7 @@ def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k:
"normalization-processor": {
"combination": {
"parameters": {
"weights": weights
"weights": [0.25, 0.75]
},
"technique": "arithmetic_mean"
},
Expand All @@ -60,12 +56,7 @@ def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k:
}
}

if subquery:
result["query"]["hybrid"]["queries"].append(filter(subquery))

for key, value in kwargs.items():
result[key] = value

return result


11 changes: 1 addition & 10 deletions chat/src/helpers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,7 @@ def get_and_send_original_question(docs):

def prepare_response(self):
try:
subquery = {
"match": {
"all_titles": {
"query": self.config.question,
"operator": "AND",
"analyzer": "english"
}
}
}
retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "subquery": subquery, "_source": {"excludes": ["embedding"]}})
retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "_source": {"excludes": ["embedding"]}})
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| self.original_question_passthrough()
Expand Down
10 changes: 4 additions & 6 deletions chat/test/helpers/test_hybrid_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@

class TestFunction(TestCase):
def test_hybrid_query(self):
subquery = { "term": { "title": { "value": "The Title" } } }
dsl = hybrid_query("Question?", "MODEL_ID", k=10, subquery=subquery)
dsl = hybrid_query("Question?", "MODEL_ID", k=10)
subject = dsl["query"]["hybrid"]["queries"]

checks = [
(lambda x: x["query_string"]["query"], "Question?"),
(lambda x: x["neural"]["embedding"]["model_id"], "MODEL_ID"),
(lambda x: x["term"]["title"]["value"], "The Title")
(lambda x: x["neural"]["embedding"]["model_id"], "MODEL_ID")
]

self.assertEqual(len(subject), 3)
self.assertEqual(len(subject), 2)

for i in range(3):
for i in range(2):
lookup, expected = checks[i]
queries = subject[i]["bool"]["must"]
self.assertEqual(lookup(queries[0]), expected)
Expand Down
4 changes: 4 additions & 0 deletions template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1159,3 +1159,7 @@ Resources:
AuthorizationType: NONE
RouteKey: GET /docs/v2/{proxy+}
Target: !Sub "integrations/${docsIntegration}"
Outputs:
Endpoint:
Description: "The base API endpoint for the stack"
Value: !Sub "https://${CustomDomainHost}.${CustomDomainZone}/api/v2"

0 comments on commit 6c84bdc

Please sign in to comment.