Tools: unify retrievers/functions and add file tools (#164)
* Tools: unify retrievers/functions and add file tools

* lint

* add file tools

* add user_id

* chunk

* force single step

* handle direct answer

* chunk all docs

* lint

* comments + fix tests

* fix non streaming chat

* comments and fix tests

* comments

* improve rerank and chunk

* comments

* fix log

* fix migration
lusmoura authored Jun 6, 2024
1 parent d0459af commit 42253c9
Showing 28 changed files with 642 additions and 305 deletions.
4 changes: 2 additions & 2 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ xmltodict = "^0.13.0"
 authlib = "^1.3.0"
 itsdangerous = "^2.2.0"
 bcrypt = "^4.1.2"
+pypdf = "^4.2.0"
 pyjwt = "^2.8.0"

 [tool.poetry.group.dev]
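The new pypdf dependency is presumably what the file tools added in this PR use to extract text from uploaded PDFs before storing it. A minimal, illustrative sketch of that kind of extraction, assuming nothing beyond the pypdf API itself (the function name and call site below are not part of this diff):

# Illustrative only: pulling text out of a PDF with pypdf.
# The actual file-tool implementation lives in other files of this PR.
from pypdf import PdfReader


def read_pdf_text(path: str) -> str:
    """Concatenate the extracted text of every page of the PDF at `path`."""
    reader = PdfReader(path)
    return "\n".join(page.extract_text() or "" for page in reader.pages)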
30 changes: 30 additions & 0 deletions src/backend/alembic/versions/f5819b10ef2a_.py
@@ -0,0 +1,30 @@
"""empty message
Revision ID: f5819b10ef2a
Revises: 3247f8fd3f71
Create Date: 2024-06-06 16:13:32.066454
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "f5819b10ef2a"
down_revision: Union[str, None] = "3247f8fd3f71"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("files", sa.Column("file_content", sa.String(), nullable=False))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("files", "file_content")
# ### end Alembic commands ###
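Since the migration adds a non-nullable file_content column to the files table, the ORM model for files presumably gains a matching field. A hypothetical SQLAlchemy sketch of what that looks like (the real File model and its other columns are defined elsewhere in the repository and are not shown in this excerpt; only file_content is taken from the migration above):

# Hypothetical sketch; the id and file_name columns are assumptions.
from sqlalchemy import Column, String
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class File(Base):
    __tablename__ = "files"

    id = Column(String, primary_key=True)          # assumed key column
    file_name = Column(String, nullable=False)     # assumed existing column
    file_content = Column(String, nullable=False)  # column added by this migration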
154 changes: 99 additions & 55 deletions src/backend/chat/collate.py
@@ -3,78 +3,122 @@

 from backend.model_deployments.base import BaseDeployment


-def combine_documents(
-    documents: Dict[str, List[Dict[str, Any]]],
-    model: BaseDeployment,
-) -> List[Dict[str, Any]]:
-    """
-    Combines documents from different retrievers and reranks them.
-    Args:
-        documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
-        model (BaseDeployment): Model deployment.
-    Returns:
-        List[Dict[str, Any]]: List of combined documents.
-    """
-    reranked_documents = rerank(documents, model)
-    return interleave(reranked_documents)
+RELEVANCE_THRESHOLD = 0.5


-def rerank(
-    documents_by_query: Dict[str, List[Dict[str, Any]]], model: BaseDeployment
-) -> Dict[str, List[Dict[str, Any]]]:
+def rerank_and_chunk(
+    tool_results: List[Dict[str, Any]], model: BaseDeployment
+) -> List[Dict[str, Any]]:
     """
-    Takes a dictionary from queries of lists of documents and
-    internally rerank the documents for each query e.g:
+    Takes a list of tool_results and internally reranks the documents for each query, if there's one e.g:
     [{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [{"q1":[2 , 3, 1],"q2": [4, 6, 5]]
     Args:
-        documents_by_query (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
+        tool_results (List[Dict[str, Any]]): List of tool_results from different retrievers.
+            Each tool_result contains a ToolCall and a list of Outputs.
         model (BaseDeployment): Model deployment.
     Returns:
-        Dict[str, List[Dict[str, Any]]]: Dictionary from queries of lists of reranked documents.
+        List[Dict[str, Any]]: List of reranked and combined documents.
     """
     # If rerank is not enabled return documents as is:
     if not model.rerank_enabled:
-        return documents_by_query
+        return tool_results

+    # Merge all the documents with the same tool call and parameters
+    unified_tool_results = {}
+    for tool_result in tool_results:
+        tool_call = tool_result["call"]
+        tool_call_hashable = str(tool_call)
+
+        if tool_call_hashable not in unified_tool_results.keys():
+            unified_tool_results[tool_call_hashable] = {
+                "call": tool_call,
+                "outputs": [],
+            }
+
+        unified_tool_results[tool_call_hashable]["outputs"].extend(
+            tool_result["outputs"]
+        )
+
+    # Rerank the documents for each query
+    reranked_results = {}
+    for tool_call_hashable, tool_result in unified_tool_results.items():
+        tool_call = tool_result["call"]
+        query = tool_call.parameters.get("query") or tool_call.parameters.get(
+            "search_query"
+        )
+
+        # Only rerank if there is a query
+        if not query:
+            reranked_results[tool_call_hashable] = tool_result
+            continue
+
+        chunked_outputs = []
+        for output in tool_result["outputs"]:
+            text = output.get("text")

-    # rerank the documents by each query
-    all_rerank_docs = {}
-    for query, documents in documents_by_query.items():
-        # Only rerank on text of document
-        # TODO handle no text in document
-        docs_to_rerank = [doc["text"] for doc in documents]
+            if not text:
+                chunked_outputs.append([output])
+                continue
+
+            chunks = chunk(text)
+            chunked_outputs.extend([dict(output, text=chunk) for chunk in chunks])

-        # If no documents to rerank, continue to the next query
-        if not docs_to_rerank:
+        if not chunked_outputs:
             continue

-        res = model.invoke_rerank(query=query, documents=docs_to_rerank)
+        res = model.invoke_rerank(query=query, documents=chunked_outputs)

         # Sort the results by relevance score
         res.results.sort(key=lambda x: x.relevance_score, reverse=True)
-        # Map the results back to the original documents
-        all_rerank_docs[query] = [documents[r.index] for r in res.results]
-
-    return all_rerank_docs
-
-
-def interleave(documents: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
-    """
-    Takes a dictionary from queries of lists of documents and interleaves them
-    for example [{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [1, 4, 2, 5, 3, 6]
-    Args:
-        documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents.
-
-    Returns:
-        List[Dict[str, Any]]: List of interleaved documents.
-    """
-    return [
-        y
-        for x in zip_longest(*documents.values(), fillvalue=None)
-        for y in x
-        if y is not None
-    ]
+        # Map the results back to the original documents
+        reranked_results[tool_call_hashable] = {
+            "call": tool_call,
+            "outputs": [
+                chunked_outputs[r.index]
+                for r in res.results
+                if r.relevance_score > RELEVANCE_THRESHOLD
+            ],
+        }
+
+    return list(reranked_results.values())
+
+
+def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300):
+    if compact_mode:
+        content = content.replace("\n", " ")
+
+    chunks = []
+    current_chunk = ""
+    words = content.split()
+    word_count = 0
+
+    for word in words:
+        if word_count + len(word.split()) > hard_word_cut_off:
+            # If adding the next word exceeds the hard limit, finalize the current chunk
+            chunks.append(current_chunk)
+            current_chunk = ""
+            word_count = 0
+
+        if word_count + len(word.split()) > soft_word_cut_off and word.endswith("."):
+            # If adding the next word exceeds the soft limit and the word ends with a period, finalize the current chunk
+            current_chunk += " " + word
+            chunks.append(current_chunk.strip())
+            current_chunk = ""
+            word_count = 0
+        else:
+            # Add the word to the current chunk
+            if current_chunk == "":
+                current_chunk = word
+            else:
+                current_chunk += " " + word
+            word_count += len(word.split())
+
+    # Add any remaining content as the last chunk
+    if current_chunk != "":
+        chunks.append(current_chunk.strip())
+
+    return chunks
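For readers skimming the diff, a small usage sketch of the new chunk helper. The import path is assumed from the file location above (src/backend/chat/collate.py), and the sample text is made up:

# Usage sketch for the chunk() helper added in this commit.
from backend.chat.collate import chunk

# A long passage built from short sentences, purely for illustration.
text = " ".join(f"This is sentence number {i}." for i in range(1, 501))

chunks = chunk(text, soft_word_cut_off=100, hard_word_cut_off=300)

# Chunks end near a sentence boundary once the soft limit is passed,
# and are force-split at the hard limit otherwise.
print(len(chunks), "chunks")
print(chunks[0][:80])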
