From ab89fe7f8dea9bc9a8ab628dea2e51a2b154598a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lu=C3=ADsa=20Moura?=
Date: Fri, 12 Jul 2024 11:51:53 +0100
Subject: [PATCH] Test sync code (#410)

---
 src/backend/chat/collate.py              |   4 +-
 src/backend/chat/custom/custom.py        | 125 +++++++++---
 src/backend/main.py                      |   4 +-
 .../model_deployments/cohere_platform.py |  18 +--
 src/backend/services/chat.py             |  94 ++++++-------
 src/backend/tools/python_interpreter.py  |   2 +-
 src/backend/tools/tavily.py              |   8 +-
 7 files changed, 122 insertions(+), 133 deletions(-)

diff --git a/src/backend/chat/collate.py b/src/backend/chat/collate.py
index 6d13fc0ee0..8ec29400c8 100644
--- a/src/backend/chat/collate.py
+++ b/src/backend/chat/collate.py
@@ -7,7 +7,7 @@
 RELEVANCE_THRESHOLD = 0.1
 
 
-async def rerank_and_chunk(
+def rerank_and_chunk(
     tool_results: List[Dict[str, Any]], model: BaseDeployment, **kwargs: Any
 ) -> List[Dict[str, Any]]:
     """
@@ -74,7 +74,7 @@ async def rerank_and_chunk(
         if not chunked_outputs:
             continue
 
-        res = await model.invoke_rerank(
+        res = model.invoke_rerank(
             query=query,
             documents=chunked_outputs,
             trace_id=trace_id,
diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py
index 9ef0a442c7..f97a306726 100644
--- a/src/backend/chat/custom/custom.py
+++ b/src/backend/chat/custom/custom.py
@@ -22,7 +22,7 @@
 class CustomChat(BaseChat):
     """Custom chat flow not using integrations for models."""
 
-    async def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
+    def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
         """
         Chat flow for custom models.
 
@@ -54,33 +54,32 @@ async def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
 
         try:
             stream = self.call_chat(self.chat_request, deployment_model, **kwargs)
-            async with AsyncGeneratorContextManager(stream) as stream:
-                async for event in stream:
-                    send_log_message(
-                        logger,
-                        f"Stream event: {event}",
-                        level="info",
-                        conversation_id=kwargs.get("conversation_id"),
-                        user_id=kwargs.get("user_id"),
-                    )
-                    result = self.handle_event(event, chat_request)
+            for event in stream:
+                send_log_message(
+                    logger,
+                    f"Stream event: {event}",
+                    level="info",
+                    conversation_id=kwargs.get("conversation_id"),
+                    user_id=kwargs.get("user_id"),
+                )
+                result = self.handle_event(event, chat_request)
 
-                    if result:
-                        yield result
+                if result:
+                    yield result
 
-                    if event[
-                        "event_type"
-                    ] == StreamEvent.STREAM_END and self.is_final_event(
-                        event, chat_request
-                    ):
-                        send_log_message(
-                            logger,
-                            f"Final event: {event}",
-                            level="info",
-                            conversation_id=kwargs.get("conversation_id"),
-                            user_id=kwargs.get("user_id"),
-                        )
-                        break
+                if event[
+                    "event_type"
+                ] == StreamEvent.STREAM_END and self.is_final_event(
+                    event, chat_request
+                ):
+                    send_log_message(
+                        logger,
+                        f"Final event: {event}",
+                        level="info",
+                        conversation_id=kwargs.get("conversation_id"),
+                        user_id=kwargs.get("user_id"),
+                    )
+                    break
         except Exception as e:
             yield {
                 "event_type": StreamEvent.STREAM_END,
@@ -138,7 +137,7 @@ def is_not_direct_answer(self, event: Dict[str, Any]) -> bool:
             and "tool_calls" in event
         )
 
-    async def call_chat(self, chat_request, deployment_model, **kwargs: Any):
+    def call_chat(self, chat_request, deployment_model, **kwargs: Any):
         trace_id = kwargs.get("trace_id", "")
         user_id = kwargs.get("user_id", "")
         agent_id = kwargs.get("agent_id", "")
@@ -186,20 +185,17 @@ async def call_chat(self, chat_request, deployment_model, **kwargs: Any):
 
         # Invoke chat stream
         has_tool_calls = False
-        async with AsyncGeneratorContextManager(
-            deployment_model.invoke_chat_stream(
-                chat_request, trace_id=trace_id, user_id=user_id, agent_id=agent_id
-            )
-        ) as chat_stream:
-            async for event in chat_stream:
-                if event["event_type"] == StreamEvent.STREAM_END:
-                    chat_request.chat_history = event["response"].get(
-                        "chat_history", []
-                    )
-                elif event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
-                    has_tool_calls = True
+        for event in deployment_model.invoke_chat_stream(
+            chat_request, trace_id=trace_id, user_id=user_id, agent_id=agent_id
+        ):
+            if event["event_type"] == StreamEvent.STREAM_END:
+                chat_request.chat_history = event["response"].get(
+                    "chat_history", []
+                )
+            elif event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
+                has_tool_calls = True
 
-                yield event
+            yield event
 
         send_log_message(
             logger,
@@ -212,7 +208,7 @@ async def call_chat(self, chat_request, deployment_model, **kwargs: Any):
         # Check for new tool calls in the chat history
         if has_tool_calls:
             # Handle tool calls
-            tool_results = await self.call_tools(
+            tool_results = self.call_tools(
                 chat_request.chat_history, deployment_model, **kwargs
             )
 
@@ -234,7 +230,7 @@ def update_chat_history_with_tool_results(
 
         chat_request.chat_history.extend(tool_results)
 
-    async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
+    def call_tools(self, chat_history, deployment_model, **kwargs: Any):
         tool_results = []
         if "tool_calls" not in chat_history[-1]:
             return tool_results
@@ -262,7 +258,7 @@ async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
             if not tool:
                 continue
 
-            outputs = await tool.implementation().call(
+            outputs = tool.implementation().call(
                 parameters=tool_call.get("parameters"),
                 session=kwargs.get("session"),
                 model_deployment=deployment_model,
@@ -277,7 +273,7 @@ async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
             for output in outputs:
                 tool_results.append({"call": tool_call, "outputs": [output]})
 
-        tool_results = await rerank_and_chunk(tool_results, deployment_model, **kwargs)
+        tool_results = rerank_and_chunk(tool_results, deployment_model, **kwargs)
         send_log_message(
             logger,
             f"Tool results: {tool_results}",
@@ -294,33 +290,32 @@ async def handle_tool_calls_stream(self, tool_results_stream):
         is_direct_answer = True
         chat_history = []
 
-        async with AsyncGeneratorContextManager(stream) as chat_stream:
-            async for event in chat_stream:
-                if event["event_type"] == StreamEvent.STREAM_END:
-                    stream_chat_history = []
-                    if "response" in event:
-                        stream_chat_history = event["response"].get("chat_history", [])
-                    elif "chat_history" in event:
-                        stream_chat_history = event["chat_history"]
+        for event in stream:
+            if event["event_type"] == StreamEvent.STREAM_END:
+                stream_chat_history = []
+                if "response" in event:
+                    stream_chat_history = event["response"].get("chat_history", [])
+                elif "chat_history" in event:
+                    stream_chat_history = event["chat_history"]
 
-                    for message in stream_chat_history:
-                        if not isinstance(message, dict):
-                            message = to_dict(message)
+                for message in stream_chat_history:
+                    if not isinstance(message, dict):
+                        message = to_dict(message)
 
-                        chat_history.append(
-                            ChatMessage(
-                                role=message.get("role"),
-                                message=message.get("message", ""),
-                                tool_results=message.get("tool_results", None),
-                                tool_calls=message.get("tool_calls", None),
-                            )
-                        )
+                    chat_history.append(
+                        ChatMessage(
+                            role=message.get("role"),
+                            message=message.get("message", ""),
+                            tool_results=message.get("tool_results", None),
+                            tool_calls=message.get("tool_calls", None),
+                        )
+                    )
 
-                elif (
-                    event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION
-                    and "tool_calls" in event
-                ):
-                    is_direct_answer = False
+            elif (
+                event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION
+                and "tool_calls" in event
+            ):
+                is_direct_answer = False
 
         return is_direct_answer, chat_history, stream_copy
 
diff --git a/src/backend/main.py b/src/backend/main.py
index b969626cbc..f8cf4d9bb0 100644
--- a/src/backend/main.py
+++ b/src/backend/main.py
@@ -73,8 +73,8 @@ def create_app():
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    app.add_middleware(LoggingMiddleware)
-    app.add_middleware(MetricsMiddleware)
+    # app.add_middleware(LoggingMiddleware)
+    # app.add_middleware(MetricsMiddleware)
 
     return app
 
diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py
index cf150a8e31..63d216b5a0 100644
--- a/src/backend/model_deployments/cohere_platform.py
+++ b/src/backend/model_deployments/cohere_platform.py
@@ -73,17 +73,15 @@ def list_models(cls) -> List[str]:
     def is_available(cls) -> bool:
         return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS])
 
-    @collect_metrics_chat
-    async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
+    # @collect_metrics_chat
+    def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
         response = self.client.chat(
             **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
         )
         yield to_dict(response)
 
-    @collect_metrics_chat_stream
-    async def invoke_chat_stream(
-        self, chat_request: CohereChatRequest, **kwargs: Any
-    ) -> AsyncGenerator[Any, Any]:
+    # @collect_metrics_chat_stream
+    def invoke_chat_stream(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
         stream = self.client.chat_stream(
             **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
         )
@@ -98,10 +96,12 @@ async def invoke_chat_stream(
         )
         yield to_dict(event)
 
-    @collect_metrics_rerank
-    async def invoke_rerank(
+    # @collect_metrics_rerank
+    def invoke_rerank(
         self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
     ) -> Any:
-        return self.client.rerank(
+        response = self.client.rerank(
             query=query, documents=documents, model=DEFAULT_RERANK_MODEL
         )
+
+        return to_dict(response)
diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py
index 255a365c4d..830325a110 100644
--- a/src/backend/services/chat.py
+++ b/src/backend/services/chat.py
@@ -518,30 +518,27 @@ async def generate_chat_response(
     )
 
     non_streamed_chat_response = None
-    async with AsyncGeneratorContextManager(stream) as chat_stream:
-        async for event in chat_stream:
-            event = json.loads(event)
-            if event["event"] == StreamEvent.STREAM_END:
-                data = event["data"]
-                response_id = response_message.id if response_message else None
-                generation_id = (
-                    response_message.generation_id if response_message else None
-                )
+    for event in stream:
+        event = json.loads(event)
+        if event["event"] == StreamEvent.STREAM_END:
+            data = event["data"]
+            response_id = response_message.id if response_message else None
+            generation_id = response_message.generation_id if response_message else None
 
-                non_streamed_chat_response = NonStreamedChatResponse(
-                    text=data.get("text", ""),
-                    response_id=response_id,
-                    generation_id=generation_id,
-                    chat_history=data.get("chat_history", []),
-                    finish_reason=data.get("finish_reason", ""),
-                    citations=data.get("citations", []),
-                    search_queries=data.get("search_queries", []),
-                    documents=data.get("documents", []),
-                    search_results=data.get("search_results", []),
-                    event_type=StreamEvent.NON_STREAMED_CHAT_RESPONSE,
-                    conversation_id=conversation_id,
-                    tool_calls=data.get("tool_calls", []),
-                )
+            non_streamed_chat_response = NonStreamedChatResponse(
+                text=data.get("text", ""),
+                response_id=response_id,
+                generation_id=generation_id,
+                chat_history=data.get("chat_history", []),
+                finish_reason=data.get("finish_reason", ""),
+                citations=data.get("citations", []),
+                search_queries=data.get("search_queries", []),
+                documents=data.get("documents", []),
+                search_results=data.get("search_results", []),
+                event_type=StreamEvent.NON_STREAMED_CHAT_RESPONSE,
+                conversation_id=conversation_id,
+                tool_calls=data.get("tool_calls", []),
+            )
 
     return non_streamed_chat_response
 
@@ -586,35 +583,32 @@ async def generate_chat_stream(
    document_ids_to_document = {}

    stream_event = None
-    async with AsyncGeneratorContextManager(
-        model_deployment_stream
-    ) as model_deployment_stream_tmp:
-        async for event in model_deployment_stream_tmp:
-            (
-                stream_event,
-                stream_end_data,
-                response_message,
-                document_ids_to_document,
-            ) = handle_stream_event(
-                event,
-                conversation_id,
-                stream_end_data,
-                response_message,
-                document_ids_to_document,
-                session=session,
-                should_store=should_store,
-                user_id=user_id,
-                next_message_position=kwargs.get("next_message_position", 0),
-            )
+    for event in model_deployment_stream:
+        (
+            stream_event,
+            stream_end_data,
+            response_message,
+            document_ids_to_document,
+        ) = handle_stream_event(
+            event,
+            conversation_id,
+            stream_end_data,
+            response_message,
+            document_ids_to_document,
+            session=session,
+            should_store=should_store,
+            user_id=user_id,
+            next_message_position=kwargs.get("next_message_position", 0),
+        )
 
-            yield json.dumps(
-                jsonable_encoder(
-                    ChatResponseEvent(
-                        event=stream_event.event_type.value,
-                        data=stream_event,
-                    )
-                )
-            )
+        yield json.dumps(
+            jsonable_encoder(
+                ChatResponseEvent(
+                    event=stream_event.event_type.value,
+                    data=stream_event,
+                )
+            )
+        )
 
     if should_store:
         update_conversation_after_turn(
diff --git a/src/backend/tools/python_interpreter.py b/src/backend/tools/python_interpreter.py
index d70dccd4d1..c1330118c2 100644
--- a/src/backend/tools/python_interpreter.py
+++ b/src/backend/tools/python_interpreter.py
@@ -29,7 +29,7 @@ class PythonInterpreter(BaseTool):
     def is_available(cls) -> bool:
         return cls.INTERPRETER_URL is not None
 
-    async def call(self, parameters: dict, **kwargs: Any):
+    def call(self, parameters: dict, **kwargs: Any):
         if not self.INTERPRETER_URL:
             raise Exception("Python Interpreter tool called while URL not set")
 
diff --git a/src/backend/tools/tavily.py b/src/backend/tools/tavily.py
index 9fe4e792af..511086a814 100644
--- a/src/backend/tools/tavily.py
+++ b/src/backend/tools/tavily.py
@@ -21,7 +21,7 @@ def __init__(self):
     def is_available(cls) -> bool:
         return cls.TAVILY_API_KEY is not None
 
-    async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
+    def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
         query = parameters.get("query", "")
         result = self.client.search(
             query=query, search_depth="advanced", include_raw_content=True
         )
@@ -49,7 +49,7 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
             }
             expanded.append(new_result)
 
-        reranked_results = await self.rerank_page_snippets(
+        reranked_results = self.rerank_page_snippets(
             query, expanded, model=kwargs.get("model_deployment"), **kwargs
         )
 
@@ -58,7 +58,7 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
             for result in reranked_results
         ]
 
-    async def rerank_page_snippets(
+    def rerank_page_snippets(
         self,
         query: str,
         snippets: List[Dict[str, Any]],
@@ -72,7 +72,7 @@
         relevance_scores = [None for _ in range(len(snippets))]
         for batch_start in range(0, len(snippets), rerank_batch_size):
             snippet_batch = snippets[batch_start : batch_start + rerank_batch_size]
-            batch_output = await model.invoke_rerank(
+            batch_output = model.invoke_rerank(
                 query=query,
                 documents=[
                     f"{snippet['title']} {snippet['content']}"
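
With this patch applied, the chat pipeline is synchronous end to end: `chat()`, `call_chat()`, and `invoke_chat_stream()` are plain generators, and `call_tools()`, `rerank_and_chunk()`, and the tools' `call()` methods are ordinary functions. A minimal sketch of the calling convention this enables follows; the import paths, request fields, and kwargs are assumptions for illustration, not taken from the patch:

    # Sketch only: module paths, field names, and kwargs are assumed,
    # not taken from the patch.
    from backend.chat.custom.custom import CustomChat
    from backend.schemas.cohere_chat import CohereChatRequest

    chat = CustomChat()
    request = CohereChatRequest(message="What is the weather today?")

    # chat() now returns a plain generator, so a regular for loop replaces
    # the old `async for` wrapped in AsyncGeneratorContextManager, and no
    # event loop is needed to drain the stream of events.
    for event in chat.chat(request, user_id="user-1", trace_id="trace-1"):
        print(event)

The same change runs through the lower layers: `invoke_rerank()` now returns a plain dict via `to_dict(response)` instead of a coroutine, which is why `rerank_and_chunk()` and the Tavily tool drop their `await`s.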