From 42253c933ab846f8e369f281cad77f90fac30e57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADsa=20Moura?= Date: Thu, 6 Jun 2024 13:10:06 -0400 Subject: [PATCH] Tools: unify retrievers/functions and add file tools (#164) * Tools: unify retrievers/functions and add file tools * lint * add file tools * add user_id * chunk * force single step * handle direct answer * chunk all docs * lint * comments + fix tests * fix non streaming chat * comments and fix tests * comments * improve rerank and chunk * comments * fix log * fix migration --- poetry.lock | 4 +- pyproject.toml | 1 + src/backend/alembic/versions/f5819b10ef2a_.py | 30 ++ src/backend/chat/collate.py | 154 ++++++---- src/backend/chat/custom/custom.py | 274 ++++++++++-------- src/backend/config/tools.py | 45 ++- src/backend/crud/file.py | 21 ++ src/backend/database_models/file.py | 1 + src/backend/model_deployments/azure.py | 17 +- src/backend/model_deployments/base.py | 8 +- src/backend/model_deployments/bedrock.py | 10 +- .../model_deployments/cohere_platform.py | 26 +- src/backend/model_deployments/sagemaker.py | 11 +- src/backend/routers/chat.py | 3 + src/backend/routers/conversation.py | 5 + src/backend/schemas/chat.py | 5 +- src/backend/schemas/cohere_chat.py | 4 + src/backend/services/chat.py | 18 +- .../mock_deployments/mock_azure.py | 4 +- .../mock_deployments/mock_bedrock.py | 26 ++ .../mock_deployments/mock_cohere_platform.py | 7 +- .../tests/model_deployments/test_azure.py | 24 -- .../tests/model_deployments/test_bedrock.py | 16 + .../model_deployments/test_cohere_platform.py | 24 -- .../tests/model_deployments/test_sagemaker.py | 26 +- src/backend/tests/tools/test_collate.py | 85 ++++-- src/backend/tools/__init__.py | 3 + src/backend/tools/files.py | 95 ++++++ 28 files changed, 642 insertions(+), 305 deletions(-) create mode 100644 src/backend/alembic/versions/f5819b10ef2a_.py create mode 100644 src/backend/tools/files.py diff --git a/poetry.lock b/poetry.lock index 49952baf98..849450be12 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -5593,4 +5593,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "fe6d0a5e2663dc702369cf498d6a2e27d0174be5be09ac75058b5148b8edd465" +content-hash = "f46b43bab59f8a25270968cd0fe95c5b6801474c2e649ffc915c1b82340296eb" diff --git a/pyproject.toml b/pyproject.toml index 64a0876a86..62755096ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ xmltodict = "^0.13.0" authlib = "^1.3.0" itsdangerous = "^2.2.0" bcrypt = "^4.1.2" +pypdf = "^4.2.0" pyjwt = "^2.8.0" [tool.poetry.group.dev] diff --git a/src/backend/alembic/versions/f5819b10ef2a_.py b/src/backend/alembic/versions/f5819b10ef2a_.py new file mode 100644 index 0000000000..785b38e02b --- /dev/null +++ b/src/backend/alembic/versions/f5819b10ef2a_.py @@ -0,0 +1,30 @@ +"""empty message + +Revision ID: f5819b10ef2a +Revises: 3247f8fd3f71 +Create Date: 2024-06-06 16:13:32.066454 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f5819b10ef2a" +down_revision: Union[str, None] = "3247f8fd3f71" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("files", sa.Column("file_content", sa.String(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("files", "file_content") + # ### end Alembic commands ### diff --git a/src/backend/chat/collate.py b/src/backend/chat/collate.py index a7b973e091..a25f7494e0 100644 --- a/src/backend/chat/collate.py +++ b/src/backend/chat/collate.py @@ -3,78 +3,122 @@ from backend.model_deployments.base import BaseDeployment - -def combine_documents( - documents: Dict[str, List[Dict[str, Any]]], - model: BaseDeployment, -) -> List[Dict[str, Any]]: - """ - Combines documents from different retrievers and reranks them. - - Args: - documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents. - model (BaseDeployment): Model deployment. - - Returns: - List[Dict[str, Any]]: List of combined documents. - """ - reranked_documents = rerank(documents, model) - return interleave(reranked_documents) +RELEVANCE_THRESHOLD = 0.5 -def rerank( - documents_by_query: Dict[str, List[Dict[str, Any]]], model: BaseDeployment -) -> Dict[str, List[Dict[str, Any]]]: +def rerank_and_chunk( + tool_results: List[Dict[str, Any]], model: BaseDeployment +) -> List[Dict[str, Any]]: """ - Takes a dictionary from queries of lists of documents and - internally rerank the documents for each query e.g: + Takes a list of tool_results and internally reranks the documents for each query, if there's one e.g: [{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [{"q1":[2 , 3, 1],"q2": [4, 6, 5]] Args: - documents_by_query (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents. + tool_results (List[Dict[str, Any]]): List of tool_results from different retrievers. + Each tool_result contains a ToolCall and a list of Outputs. model (BaseDeployment): Model deployment. Returns: - Dict[str, List[Dict[str, Any]]]: Dictionary from queries of lists of reranked documents. + List[Dict[str, Any]]: List of reranked and combined documents. """ # If rerank is not enabled return documents as is: if not model.rerank_enabled: - return documents_by_query + return tool_results + + # Merge all the documents with the same tool call and parameters + unified_tool_results = {} + for tool_result in tool_results: + tool_call = tool_result["call"] + tool_call_hashable = str(tool_call) + + if tool_call_hashable not in unified_tool_results.keys(): + unified_tool_results[tool_call_hashable] = { + "call": tool_call, + "outputs": [], + } + + unified_tool_results[tool_call_hashable]["outputs"].extend( + tool_result["outputs"] + ) + + # Rerank the documents for each query + reranked_results = {} + for tool_call_hashable, tool_result in unified_tool_results.items(): + tool_call = tool_result["call"] + query = tool_call.parameters.get("query") or tool_call.parameters.get( + "search_query" + ) + + # Only rerank if there is a query + if not query: + reranked_results[tool_call_hashable] = tool_result + continue + + chunked_outputs = [] + for output in tool_result["outputs"]: + text = output.get("text") - # rerank the documents by each query - all_rerank_docs = {} - for query, documents in documents_by_query.items(): - # Only rerank on text of document - # TODO handle no text in document - docs_to_rerank = [doc["text"] for doc in documents] + if not text: + chunked_outputs.append([output]) + continue + + chunks = chunk(text) + chunked_outputs.extend([dict(output, text=chunk) for chunk in chunks]) # If no documents to rerank, continue to the next query - if not docs_to_rerank: + if not chunked_outputs: continue - res = model.invoke_rerank(query=query, documents=docs_to_rerank) + res = model.invoke_rerank(query=query, documents=chunked_outputs) + # Sort the results by relevance score res.results.sort(key=lambda x: x.relevance_score, reverse=True) - # Map the results back to the original documents - all_rerank_docs[query] = [documents[r.index] for r in res.results] - - return all_rerank_docs - - -def interleave(documents: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: - """ - Takes a dictionary from queries of lists of documents and interleaves them - for example [{"q1":[1, 2, 3],"q2": [4, 5, 6]] -> [1, 4, 2, 5, 3, 6] - - Args: - documents (Dict[str, List[Dict[str, Any]]]): Dictionary from queries of lists of documents. - Returns: - List[Dict[str, Any]]: List of interleaved documents. - """ - return [ - y - for x in zip_longest(*documents.values(), fillvalue=None) - for y in x - if y is not None - ] + # Map the results back to the original documents + reranked_results[tool_call_hashable] = { + "call": tool_call, + "outputs": [ + chunked_outputs[r.index] + for r in res.results + if r.relevance_score > RELEVANCE_THRESHOLD + ], + } + + return list(reranked_results.values()) + + +def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300): + if compact_mode: + content = content.replace("\n", " ") + + chunks = [] + current_chunk = "" + words = content.split() + word_count = 0 + + for word in words: + if word_count + len(word.split()) > hard_word_cut_off: + # If adding the next word exceeds the hard limit, finalize the current chunk + chunks.append(current_chunk) + current_chunk = "" + word_count = 0 + + if word_count + len(word.split()) > soft_word_cut_off and word.endswith("."): + # If adding the next word exceeds the soft limit and the word ends with a period, finalize the current chunk + current_chunk += " " + word + chunks.append(current_chunk.strip()) + current_chunk = "" + word_count = 0 + else: + # Add the word to the current chunk + if current_chunk == "": + current_chunk = word + else: + current_chunk += " " + word + word_count += len(word.split()) + + # Add any remaining content as the last chunk + if current_chunk != "": + chunks.append(current_chunk.strip()) + + return chunks diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 7285c95e92..470c2c1d4d 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -1,13 +1,17 @@ import logging -from typing import Any +from itertools import tee +from typing import Any, Dict, Generator, List from fastapi import HTTPException from backend.chat.base import BaseChat -from backend.chat.collate import combine_documents +from backend.chat.collate import rerank_and_chunk from backend.chat.custom.utils import get_deployment +from backend.chat.enums import StreamEvent from backend.config.tools import AVAILABLE_TOOLS, ToolName +from backend.crud.file import get_files_by_conversation_id from backend.model_deployments.base import BaseDeployment +from backend.schemas.chat import ChatMessage from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.tool import Category, Tool from backend.services.logger import get_logger @@ -38,133 +42,175 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: status_code=400, detail="Both tools and documents cannot be provided." ) - if kwargs.get("managed_tools", True): - # Generate Search Queries - chat_history = [message.to_dict() for message in chat_request.chat_history] - - function_tools: list[Tool] = [] - for tool in chat_request.tools: - available_tool = AVAILABLE_TOOLS.get(tool.name) - if available_tool and available_tool.category == Category.Function: - function_tools.append(Tool(**available_tool.model_dump())) - - if len(function_tools) > 0: - tool_results = self.get_tool_results( - chat_request.message, function_tools, deployment_model - ) - - chat_request.tools = None - if kwargs.get("stream", True) is True: - return deployment_model.invoke_chat_stream( - chat_request, - tool_results=tool_results, - ) - else: - return deployment_model.invoke_chat( - chat_request, - tool_results=tool_results, - ) + # If a direct answer is generated instead of tool calls, the chat will not be called again + # Instead, the direct answer will be returned from the stream + stream = self.handle_managed_tools(chat_request, deployment_model, **kwargs) - queries = deployment_model.invoke_search_queries( - chat_request.message, chat_history - ) - logger.info(f"Search queries generated: {queries}") + first_event, generated_direct_answer = next(stream) - # Fetch Documents - retrievers = self.get_retrievers( - kwargs.get("file_paths", []), [tool.name for tool in chat_request.tools] - ) - logger.info( - f"Using retrievers: {[retriever.__class__.__name__ for retriever in retrievers]}" + if generated_direct_answer: + yield first_event + for event, _ in stream: + yield event + else: + chat_request = first_event + invoke_method = ( + deployment_model.invoke_chat_stream + if kwargs.get("stream", True) + else deployment_model.invoke_chat ) - # No search queries were generated but retrievers were selected, use user message as query - if len(queries) == 0 and len(retrievers) > 0: - queries = [chat_request.message] - - all_documents = {} - # TODO: call in parallel and error handling - # TODO: merge with regular function tools after multihop implemented - for retriever in retrievers: - for query in queries: - parameters = {"query": query} - all_documents.setdefault(query, []).extend( - retriever.call(parameters) - ) - - # Collate Documents - documents = combine_documents(all_documents, deployment_model) - chat_request.documents = documents - chat_request.tools = [] - - # Generate Response - if kwargs.get("stream", True) is True: - return deployment_model.invoke_chat_stream(chat_request) - else: - return deployment_model.invoke_chat(chat_request) + yield from invoke_method(chat_request) - def get_retrievers( - self, file_paths: list[str], req_tools: list[ToolName] - ) -> list[Any]: + def handle_managed_tools( + self, + chat_request: CohereChatRequest, + deployment_model: BaseDeployment, + **kwargs: Any, + ) -> Generator[Any, None, None]: """ - Get retrievers for the required tools. + This function handles the managed tools. Args: - file_paths (list[str]): File paths. - req_tools (list[str]): Required tools. + chat_request (CohereChatRequest): The chat request + deployment_model (BaseDeployment): The deployment model + **kwargs (Any): The keyword arguments Returns: - list[Any]: Retriever implementations. + Generator[Any, None, None]: The tool results or the chat response, and a boolean indicating if a direct answer was generated """ - retrievers = [] - - # If no tools are required, return an empty list - if not req_tools: - return retrievers - - # Iterate through the required tools and check if they are available - # If so, add the implementation to the list of retrievers - # If not, raise an HTTPException - for tool_name in req_tools: - tool = AVAILABLE_TOOLS.get(tool_name) - if tool is None: - raise HTTPException( - status_code=404, detail=f"Tool {tool_name} not found." - ) - - # Check if the tool is visible, if not, skip it - if not tool.is_visible: - continue - - if tool.category == Category.FileLoader and file_paths is not None: - for file_path in file_paths: - retrievers.append(tool.implementation(file_path, **tool.kwargs)) - elif tool.category != Category.FileLoader: - retrievers.append(tool.implementation(**tool.kwargs)) - - return retrievers + tools = [ + Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump()) + for tool in chat_request.tools + if AVAILABLE_TOOLS.get(tool.name) + ] + + if not tools: + yield chat_request, False + + for event, should_return in self.get_tool_results( + chat_request.message, + chat_request.chat_history, + tools, + kwargs.get("conversation_id"), + deployment_model, + kwargs, + ): + if should_return: + yield event, True + else: + chat_request.tool_results = event + chat_request.tools = tools + yield chat_request, False def get_tool_results( - self, message: str, tools: list[Tool], model: BaseDeployment - ) -> list[dict[str, Any]]: - tool_results = [] - tools_to_use = model.invoke_tools(message, tools) + self, + message: str, + chat_history: List[Dict[str, str]], + tools: list[Tool], + conversation_id: str, + deployment_model: BaseDeployment, + kwargs: Any, + ) -> Any: + """ + Invokes the tools and returns the results. If no tools calls are generated, it returns the chat response + as a direct answer. - tool_calls = tools_to_use.tool_calls if tools_to_use.tool_calls else [] - for tool_call in tool_calls: - tool = AVAILABLE_TOOLS.get(tool_call.name) - if not tool: - logging.warning(f"Couldn't find tool {tool_call.name}") - continue + Args: + message (str): The message to be processed + chat_history (List[Dict[str, str]]): The chat history + tools (list[Tool]): The tools to be invoked + conversation_id (str): The conversation ID + deployment_model (BaseDeployment): The deployment model + kwargs (Any): The keyword arguments + + Returns: + Any: The tool results or the chat response, and a boolean indicating if a direct answer was generated - outputs = tool.implementation().call( - parameters=tool_call.parameters, + """ + tool_results = [] + + # If the tool is Read_File or SearchFile, add the available files to the chat history + # so that the model knows what files are available + tool_names = [tool.name for tool in tools] + if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names: + chat_history = self.add_files_to_chat_history( + chat_history, + conversation_id, + kwargs.get("session"), + kwargs.get("user_id"), ) - # If the tool returns a list of outputs, append each output to the tool_results list - # Otherwise, append the single output to the tool_results list - outputs = outputs if isinstance(outputs, list) else [outputs] - for output in outputs: - tool_results.append({"call": tool_call, "outputs": [output]}) + logger.info(f"Invoking tools: {tools}") + stream = deployment_model.invoke_tools( + message, tools, chat_history=chat_history + ) + + # Invoke tools can return a direct answer or a stream of events with the tool calls + # If one of the events is a tool call generation, the tools are invoked, and the results are returned + # Otherwise, the chat response is returned as a direct answer + stream, stream_copy = tee(stream) + + tool_call_found = False + for event in stream: + if event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION: + tool_call_found = True + tool_calls = event["tool_calls"] + + logger.info(f"Tool calls: {tool_calls}") + + # TODO: parallelize tool calls + for tool_call in tool_calls: + tool = AVAILABLE_TOOLS.get(tool_call.name) + if not tool: + logging.warning(f"Couldn't find tool {tool_call.name}") + continue + + outputs = tool.implementation().call( + parameters=tool_call.parameters, + session=kwargs.get("session"), + model_deployment=deployment_model, + user_id=kwargs.get("user_id"), + ) - return tool_results + # If the tool returns a list of outputs, append each output to the tool_results list + # Otherwise, append the single output to the tool_results list + outputs = outputs if isinstance(outputs, list) else [outputs] + for output in outputs: + tool_results.append({"call": tool_call, "outputs": [output]}) + + tool_results = rerank_and_chunk(tool_results, deployment_model) + logger.info(f"Tool results: {tool_results}") + yield tool_results, False + break + + if not tool_call_found: + for event in stream_copy: + yield event, True + + def add_files_to_chat_history( + self, + chat_history: List[Dict[str, str]], + conversation_id: str, + session: Any, + user_id: str, + ) -> List[Dict[str, str]]: + if session is None or conversation_id is None or len(conversation_id) == 0: + return chat_history + + available_files = get_files_by_conversation_id( + session, conversation_id, user_id + ) + files_message = "The user uploaded the following attachments:\n" + + for file in available_files: + word_count = len(file.file_content.split()) + + # Use the first 25 words as the document preview in the preamble + num_words = min(25, word_count) + preview = " ".join(file.file_content.split()[:num_words]) + + files_message += f"Filename: {file.file_name}\nWord Count: {word_count} Preview: {preview}\n\n" + + chat_history.append(ChatMessage(message=files_message, role="SYSTEM")) + return chat_history diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 01247222e3..aadc569d99 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -6,9 +6,10 @@ from backend.schemas.tool import Category, ManagedTool from backend.tools import ( Calculator, - LangChainVectorDBRetriever, LangChainWikiRetriever, PythonInterpreter, + ReadFileTool, + SearchFileTool, TavilyInternetSearch, ) @@ -26,10 +27,11 @@ class ToolName(StrEnum): Wiki_Retriever_LangChain = "Wikipedia" - File_Upload_Langchain = "File Reader" + Search_File = "search_file" + Read_File = "read_document" Python_Interpreter = "Python_Interpreter" Calculator = "Calculator" - Tavily_Internet_Search = "Internet Search" + Tavily_Internet_Search = "Internet_Search" ALL_TOOLS = { @@ -50,21 +52,42 @@ class ToolName(StrEnum): category=Category.DataLoader, description="Retrieves documents from Wikipedia using LangChain.", ), - ToolName.File_Upload_Langchain: ManagedTool( - name=ToolName.File_Upload_Langchain, - implementation=LangChainVectorDBRetriever, + ToolName.Search_File: ManagedTool( + name=ToolName.Search_File, + implementation=SearchFileTool, parameter_definitions={ - "query": { - "description": "Query for retrieval.", + "search_query": { + "description": "Textual search query to search over the file's content for", + "type": "str", + "required": True, + }, + "filenames": { + "description": "A list of one or more uploaded filename strings to search over", + "type": "list", + "required": True, + }, + }, + is_visible=True, + is_available=SearchFileTool.is_available(), + error_message="SearchFileTool not available.", + category=Category.FileLoader, + description="Performs a search over a list of one or more of the attached files for a textual search query", + ), + ToolName.Read_File: ManagedTool( + name=ToolName.Read_File, + implementation=ReadFileTool, + parameter_definitions={ + "filename": { + "description": "The name of the attached file to read.", "type": "str", "required": True, } }, is_visible=True, - is_available=LangChainVectorDBRetriever.is_available(), - error_message="LangChainVectorDBRetriever not available, please make sure to set the COHERE_API_KEY environment variable.", + is_available=ReadFileTool.is_available(), + error_message="ReadFileTool not available.", category=Category.FileLoader, - description="Retrieves documents from a file using LangChain.", + description="Returns the textual contents of an uploaded file, broken up in text chunks.", ), ToolName.Python_Interpreter: ManagedTool( name=ToolName.Python_Interpreter, diff --git a/src/backend/crud/file.py b/src/backend/crud/file.py index fa971d86f3..e063942ebb 100644 --- a/src/backend/crud/file.py +++ b/src/backend/crud/file.py @@ -90,6 +90,27 @@ def get_files_by_ids(db: Session, file_ids: list[str], user_id: str) -> list[Fil return db.query(File).filter(File.id.in_(file_ids), File.user_id == user_id).all() +def get_files_by_file_names( + db: Session, file_names: list[str], user_id: str +) -> list[File]: + """ + Get files by file names. + + Args: + db (Session): Database session. + file_names (list[str]): File names. + user_id (str): User ID. + + Returns: + list[File]: List of files with the given file names. + """ + return ( + db.query(File) + .filter(File.file_name.in_(file_names), File.user_id == user_id) + .all() + ) + + def update_file(db: Session, file: File, new_file: UpdateFile) -> File: """ Update a file by ID. diff --git a/src/backend/database_models/file.py b/src/backend/database_models/file.py index dce6436ae7..b11338ca6c 100644 --- a/src/backend/database_models/file.py +++ b/src/backend/database_models/file.py @@ -19,6 +19,7 @@ class File(Base): file_name: Mapped[str] file_path: Mapped[str] file_size: Mapped[int] = mapped_column(default=0) + file_content: Mapped[str] = mapped_column(default="") __table_args__ = ( Index("file_conversation_id_user_id", conversation_id, user_id), diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index 87c43a328d..a3b78e0dff 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -52,7 +52,7 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in AZURE_ENV_VARS]) def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: - return self.client.chat( + yield self.client.chat( **chat_request.model_dump(exclude={"stream"}), **kwargs, ) @@ -90,5 +90,16 @@ def invoke_rerank( ) -> Any: return None - def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> List[Any]: - return self.client.chat(message=message, tools=tools, **kwargs) + def invoke_tools( + self, + message: str, + tools: List[Any], + chat_history: List[Dict[str, str]] | None = None, + **kwargs: Any, + ) -> Generator[StreamedChatResponse, None, None]: + stream = self.client.chat_stream( + message=message, tools=tools, chat_history=chat_history, **kwargs + ) + + for event in stream: + yield event.__dict__ diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index e93bbb57a9..f44e2bbc13 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -50,4 +50,10 @@ def invoke_rerank( ) -> Any: ... @abstractmethod - def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> Any: ... + def invoke_tools( + self, + message: str, + tools: List[Any], + chat_history: List[Dict[str, str]] | None = None, + **kwargs: Any + ) -> Generator[StreamedChatResponse, None, None]: ... diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index f41defabec..b2dbb03f38 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -58,7 +58,7 @@ def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True ) - return self.client.chat( + yield self.client.chat( **bedrock_chat_req, **kwargs, ) @@ -101,5 +101,11 @@ def invoke_rerank( ) -> Any: return None - def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> List[Any]: + def invoke_tools( + self, + message: str, + tools: List[Any], + chat_history: List[Dict[str, str]] | None = None, + **kwargs: Any, + ) -> Generator[StreamedChatResponse, None, None]: return None diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index 8a61f76b04..cfa89bd53a 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -58,8 +58,9 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS]) def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: - return self.client.chat( + yield self.client.chat( **chat_request.model_dump(exclude={"stream"}), + force_single_step=True, **kwargs, ) @@ -67,7 +68,8 @@ def invoke_chat_stream( self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: stream = self.client.chat_stream( - **chat_request.model_dump(exclude={"stream"}), + **chat_request.model_dump(exclude={"stream", "file_ids"}), + force_single_step=True, **kwargs, ) for event in stream: @@ -98,7 +100,21 @@ def invoke_rerank( query=query, documents=documents, model="rerank-english-v2.0", **kwargs ) - def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> List[Any]: - return self.client.chat( - message=message, tools=tools, model="command-r", **kwargs + def invoke_tools( + self, + message: str, + tools: List[Any], + chat_history: List[Dict[str, str]] | None = None, + **kwargs: Any, + ) -> Generator[StreamedChatResponse, None, None]: + stream = self.client.chat_stream( + message=message, + tools=tools, + model="command-r", + force_single_step=True, + chat_history=chat_history, + **kwargs, ) + + for event in stream: + yield event.__dict__ diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index b9838b6f15..0d39ce8a98 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -95,7 +95,7 @@ def invoke_search_queries( self, message: str, chat_history: List[Dict[str, str]] | None = None, - **kwargs: Any + **kwargs: Any, ) -> list[str]: # Create the payload for the request json_params = { @@ -115,6 +115,15 @@ def invoke_rerank( ) -> Any: return None + def invoke_tools( + self, + message: str, + tools: List[Any], + chat_history: List[Dict[str, str]] | None = None, + **kwargs: Any, + ) -> Generator[StreamedChatResponse, None, None]: + return None + # This class iterates through each line of Sagemaker's response # https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/ class LineIterator: diff --git a/src/backend/routers/chat.py b/src/backend/routers/chat.py index 227d8f90e6..15569b805d 100644 --- a/src/backend/routers/chat.py +++ b/src/backend/routers/chat.py @@ -66,6 +66,9 @@ async def chat_stream( deployment_config=deployment_config, file_paths=file_paths, managed_tools=managed_tools, + session=session, + conversation_id=conversation_id, + user_id=user_id, ), response_message, conversation_id, diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index 58c533d367..591e4908c9 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -18,6 +18,7 @@ from backend.schemas.file import DeleteFile, File, ListFile, UpdateFile, UploadFile from backend.services.auth.utils import get_header_user_id from backend.services.file.service import FileService +from backend.tools.files import get_file_content router = APIRouter( prefix="/v1/conversations", @@ -203,6 +204,9 @@ async def upload_file( # Handle uploading File file_path = FileService().upload_file(file) + # Read file content + content = get_file_content(file_path) + # Raise exception if file wasn't uploaded if not file_path.exists(): raise HTTPException( @@ -216,6 +220,7 @@ async def upload_file( file_name=file_path.name, file_path=str(file_path), file_size=file_path.stat().st_size, + file_content=content, ) upload_file = file_crud.create_file(session, upload_file) diff --git a/src/backend/schemas/chat.py b/src/backend/schemas/chat.py index c3076a22ec..4ac80b8d01 100644 --- a/src/backend/schemas/chat.py +++ b/src/backend/schemas/chat.py @@ -12,10 +12,11 @@ class ChatRole(StrEnum): - """One of CHATBOT|USER to identify who the message is coming from.""" + """One of CHATBOT|USER|SYSTEM to identify who the message is coming from.""" CHATBOT = "CHATBOT" USER = "USER" + SYSTEM = "SYSTEM" class ChatCitationQuality(StrEnum): @@ -36,7 +37,7 @@ class ChatMessage(BaseModel): """A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message.""" role: ChatRole = Field( - title="One of CHATBOT|USER to identify who the message is coming from.", + title="One of CHATBOT|USER|SYSTEM to identify who the message is coming from.", ) message: str = Field( title="Contents of the chat message.", diff --git a/src/backend/schemas/cohere_chat.py b/src/backend/schemas/cohere_chat.py index 3605df6c19..9051dbbadf 100644 --- a/src/backend/schemas/cohere_chat.py +++ b/src/backend/schemas/cohere_chat.py @@ -101,3 +101,7 @@ class CohereChatRequest(BaseChatRequest): default=CohereChatPromptTruncation.AUTO_PRESERVE_ORDER, title="Dictates how the prompt will be constructed. Defaults to 'AUTO_PRESERVE_ORDER'.", ) + tool_results: List[Dict[str, Any]] | None = Field( + default=None, + title="A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations.", + ) diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index cb41c73023..a89b0ffdd6 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -536,19 +536,23 @@ def generate_chat_response( Returns: NonStreamedChatResponse: Chat response. """ - + model_deployment_response = next(model_deployment_response) if not isinstance(model_deployment_response, dict): response = model_deployment_response.__dict__ else: response = model_deployment_response - chat_history = [ - ChatMessage( - role=message.role, - message=message.message, + chat_history = [] + for message in response.get("chat_history", []): + if not isinstance(message, dict): + message = message.__dict__ + + chat_history.append( + ChatMessage( + role=message["role"], + message=message["message"], + ) ) - for message in response.get("chat_history", []) - ] documents = [] if "documents" in response and response["documents"]: diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/model_deployments/mock_deployments/mock_azure.py index 5db239909a..b07f4cfc02 100644 --- a/src/backend/tests/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_azure.py @@ -28,7 +28,7 @@ def is_available(cls) -> bool: return True def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: - return { + event = { "text": "Hi! Hello there! How's it going?", "generation_id": "ca0f398e-f8c8-48f0-b093-12d1754d00ed", "citations": None, @@ -50,6 +50,8 @@ def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: }, } + yield event + def invoke_chat_stream( self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py index 1a4e00d690..2b081e3aa0 100644 --- a/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py @@ -24,6 +24,32 @@ def list_models(cls) -> List[str]: def is_available(cls) -> bool: return True + def invoke_chat( + self, chat_request: CohereChatRequest, **kwargs: Any + ) -> Generator[StreamedChatResponse, None, None]: + event = { + "text": "Hi! Hello there! How's it going?", + "generation_id": "ca0f398e-f8c8-48f0-b093-12d1754d00ed", + "citations": None, + "documents": None, + "is_search_required": None, + "search_queries": None, + "search_results": None, + "finish_reason": "MAX_TOKENS", + "tool_calls": None, + "chat_history": [ + {"role": "USER", "message": "Hello"}, + {"role": "CHATBOT", "message": "Hi! Hello there! How's it going?"}, + ], + "response_id": "7f2c0ab4-e0d0-4808-891e-d5c6362e407a", + "meta": { + "api_version": {"version": "1"}, + "billed_units": {"input_tokens": 1, "output_tokens": 10}, + "tokens": {"input_tokens": 67, "output_tokens": 10}, + }, + } + yield event + def invoke_chat_stream( self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py index efc8f4a63a..347b524ca5 100644 --- a/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py @@ -24,8 +24,10 @@ def list_models(cls) -> List[str]: def is_available(cls) -> bool: return True - def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: - return { + def invoke_chat( + self, chat_request: CohereChatRequest, **kwargs: Any + ) -> Generator[StreamedChatResponse, None, None]: + event = { "text": "Hi! Hello there! How's it going?", "generation_id": "ca0f398e-f8c8-48f0-b093-12d1754d00ed", "citations": None, @@ -46,6 +48,7 @@ def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: "tokens": {"input_tokens": 67, "output_tokens": 10}, }, } + yield event def invoke_chat_stream( self, chat_request: CohereChatRequest, **kwargs: Any diff --git a/src/backend/tests/model_deployments/test_azure.py b/src/backend/tests/model_deployments/test_azure.py index 68f726faaa..d31774bb0a 100644 --- a/src/backend/tests/model_deployments/test_azure.py +++ b/src/backend/tests/model_deployments/test_azure.py @@ -60,7 +60,6 @@ def test_non_streamed_chat( mock_available_model_deployments, ): deployment = mock_azure_deployment.return_value - deployment.invoke_chat = MagicMock() response = session_client_chat.post( "/v1/chat", headers={ @@ -72,26 +71,3 @@ def test_non_streamed_chat( assert response.status_code == 200 assert type(deployment) is MockAzureDeployment - deployment.invoke_chat.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) diff --git a/src/backend/tests/model_deployments/test_bedrock.py b/src/backend/tests/model_deployments/test_bedrock.py index 7b32f75147..1650591d49 100644 --- a/src/backend/tests/model_deployments/test_bedrock.py +++ b/src/backend/tests/model_deployments/test_bedrock.py @@ -50,3 +50,19 @@ def test_streamed_chat( prompt_truncation="AUTO_PRESERVE_ORDER", ) ) + + +def test_non_streamed_chat( + session_client_chat: TestClient, + user: User, + mock_bedrock_deployment, + mock_available_model_deployments, +): + deployment = mock_bedrock_deployment.return_value + response = session_client_chat.post( + "/v1/chat", + headers={"User-Id": user.id, "Deployment-Name": ModelDeploymentName.Bedrock}, + json={"message": "Hello", "max_tokens": 10}, + ) + + assert response.status_code == 200 diff --git a/src/backend/tests/model_deployments/test_cohere_platform.py b/src/backend/tests/model_deployments/test_cohere_platform.py index 6bca524788..db60c93c4f 100644 --- a/src/backend/tests/model_deployments/test_cohere_platform.py +++ b/src/backend/tests/model_deployments/test_cohere_platform.py @@ -60,7 +60,6 @@ def test_non_streamed_chat( mock_available_model_deployments, ): deployment = mock_cohere_deployment.return_value - deployment.invoke_chat = MagicMock() response = session_client_chat.post( "/v1/chat", headers={ @@ -72,26 +71,3 @@ def test_non_streamed_chat( assert response.status_code == 200 assert type(deployment) is MockCohereDeployment - deployment.invoke_chat.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) diff --git a/src/backend/tests/model_deployments/test_sagemaker.py b/src/backend/tests/model_deployments/test_sagemaker.py index cfbcdcc8fc..a2f53c7398 100644 --- a/src/backend/tests/model_deployments/test_sagemaker.py +++ b/src/backend/tests/model_deployments/test_sagemaker.py @@ -50,6 +50,7 @@ def test_streamed_chat( ) +@pytest.mark.skip("Non-streamed chat is not supported for SageMaker yet") def test_non_streamed_chat( session_client_chat: TestClient, user: User, @@ -57,7 +58,6 @@ def test_non_streamed_chat( mock_available_model_deployments, ): deployment = mock_sagemaker_deployment.return_value - deployment.invoke_chat = MagicMock() response = session_client_chat.post( "/v1/chat", headers={"User-Id": user.id, "Deployment-Name": ModelDeploymentName.SageMaker}, @@ -65,27 +65,3 @@ def test_non_streamed_chat( ) assert response.status_code == 200 - assert type(deployment) is MockSageMakerDeployment - deployment.invoke_chat.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) diff --git a/src/backend/tests/tools/test_collate.py b/src/backend/tests/tools/test_collate.py index c79d0a6bfe..6f0455cfe3 100644 --- a/src/backend/tests/tools/test_collate.py +++ b/src/backend/tests/tools/test_collate.py @@ -4,6 +4,7 @@ from backend.chat import collate from backend.model_deployments import CohereDeployment +from backend.schemas.tool import ToolCall is_cohere_env_set = ( os.environ.get("COHERE_API_KEY") is not None @@ -14,30 +15,62 @@ @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_rerank() -> None: model = CohereDeployment(model_config={}) - input = { - "mountain": [{"text": "hill"}, {"text": "cable"}, {"text": "goat"}], - "computer": [{"text": "software"}, {"text": "penguin"}, {"text": "cable"}], - } - assert collate.rerank(input, model) == { - "mountain": [{"text": "hill"}, {"text": "goat"}, {"text": "cable"}], - "computer": [{"text": "cable"}, {"text": "software"}, {"text": "penguin"}], - } - - -def test_interleave() -> None: - input = { - "q1": [{"q1a": "a"}, {"q1b": "b"}, {"q1c": "c"}], - "q2": [{"q2a": "a"}, {"q2b": "b"}, {"q2c": "c"}], - "q3": [{"q3a": "a"}, {"q3b": "b"}, {"q3c": "c"}], - } - assert collate.interleave(input) == [ - {"q1a": "a"}, - {"q2a": "a"}, - {"q3a": "a"}, - {"q1b": "b"}, - {"q2b": "b"}, - {"q3b": "b"}, - {"q1c": "c"}, - {"q2c": "c"}, - {"q3c": "c"}, + tool_results = [ + { + "call": ToolCall(parameters={"query": "mountain"}, name="retriever"), + "outputs": [{"text": "hill"}, {"text": "goat"}, {"text": "cable"}], + }, + { + "call": ToolCall(parameters={"query": "computer"}, name="retriever"), + "outputs": [{"text": "cable"}, {"text": "software"}, {"text": "penguin"}], + }, ] + + expected_output = [ + { + "call": ToolCall(name="retriever", parameters={"query": "mountain"}), + "outputs": [], + }, + { + "call": ToolCall(name="retriever", parameters={"query": "computer"}), + "outputs": [], + }, + ] + + assert collate.rerank_and_chunk(tool_results, model) == expected_output + + +def test_chunk_normal_mode() -> None: + content = "This is a test. We are testing the chunk function." + expected_output = ["This is a test.", "We are testing the chunk function."] + collate.chunk(content, False, 4, 10) == expected_output + + +def test_chunk_compact_mode() -> None: + content = "This is a test.\nWe are testing the chunk function." + expected_output = ["This is a test.", "We are testing the chunk function."] + collate.chunk(content, True, 4, 10) == expected_output + + +def test_chunk_hard_cut_off() -> None: + content = "This is a test. We are testing the chunk function. This sentence will exceed the hard cut off." + expected_output = [ + "This is a test. We are testing the chunk function.", + "This sentence will exceed the hard cut off.", + ] + collate.chunk(content, False, 4, 10) == expected_output + + +def test_chunk_soft_cut_off() -> None: + content = "This is a test. We are testing the chunk function. This sentence will exceed the soft cut off." + expected_output = [ + "This is a test.", + "We are testing the chunk function. This sentence will exceed the soft cut off.", + ] + collate.chunk(content, False, 4, 10) == expected_output + + +def test_chunk_empty_content() -> None: + content = "" + expected_output = [] + collate.chunk(content, False, 4, 10) == expected_output diff --git a/src/backend/tools/__init__.py b/src/backend/tools/__init__.py index d3b83617e1..54090cade4 100644 --- a/src/backend/tools/__init__.py +++ b/src/backend/tools/__init__.py @@ -1,4 +1,5 @@ from backend.tools.calculator import Calculator +from backend.tools.files import ReadFileTool, SearchFileTool from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever from backend.tools.python_interpreter import PythonInterpreter from backend.tools.tavily import TavilyInternetSearch @@ -9,4 +10,6 @@ "LangChainVectorDBRetriever", "LangChainWikiRetriever", "TavilyInternetSearch", + "ReadFileTool", + "SearchFileTool", ] diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py new file mode 100644 index 0000000000..7804fdec1c --- /dev/null +++ b/src/backend/tools/files.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, List + +from pypdf import PdfReader + +import backend.crud.file as file_crud +from backend.tools.base import BaseTool + + +class ReadFileTool(BaseTool): + """ + This class reads a file from the file system. + """ + + MAX_NUM_CHUNKS = 10 + + def __init__(self): + pass + + @classmethod + def is_available(cls) -> bool: + return True + + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + file_name = parameters.get("filename", "") + session = kwargs.get("session") + user_id = kwargs.get("user_id") + + if not file_name: + return [] + + files = file_crud.get_files_by_file_names(session, [file_name], user_id) + + if not files: + return [] + + file = files[0] + return [ + { + "text": file.file_content, + "title": file.file_name, + "url": file.file_path, + } + ] + + +class SearchFileTool(BaseTool): + """ + This class searches for a query in a file. + """ + + MAX_NUM_CHUNKS = 10 + + def __init__(self): + pass + + @classmethod + def is_available(cls) -> bool: + return True + + def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + query = parameters.get("search_query") + file_names = parameters.get("filenames") + model_deployment = kwargs.get("model_deployment") + session = kwargs.get("session") + user_id = kwargs.get("user_id") + + if not query or not file_names: + return [] + + files = file_crud.get_files_by_file_names(session, file_names, user_id) + + if not files: + return [] + + results = [] + for file in files: + results.append( + { + "text": file.file_content, + "title": file.file_name, + "url": file.file_path, + } + ) + + return results + + +def get_file_content(file_path): + # Currently only supports PDF files + loader = PdfReader(file_path) + text = "" + for page in loader.pages: + text += page.extract_text() + "\n" + + return text