Test sync code (#410)
lusmoura authored Jul 12, 2024
1 parent 9a31d17 commit ab89fe7
Showing 7 changed files with 122 additions and 133 deletions.
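Every file below follows the same mechanical conversion: `async def` generators and `await` calls become plain functions and direct calls, `async for` becomes `for`, and the `AsyncGeneratorContextManager` wrappers disappear. A minimal before/after sketch of the pattern (illustrative names only, not code from this repo):

    from typing import Any, Dict, Generator

    def invoke_chat_stream(message: str) -> Generator[Dict[str, Any], None, None]:
        # Previously an `async def` generator consumed with `async for` inside
        # an AsyncGeneratorContextManager; now a plain generator.
        for token in message.split():
            yield {"event_type": "text-generation", "text": token}
        yield {"event_type": "stream-end", "response": {"chat_history": []}}

    # Callers iterate directly, with no event loop and no `await`:
    for event in invoke_chat_stream("hello sync world"):
        print(event["event_type"])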
4 changes: 2 additions & 2 deletions src/backend/chat/collate.py
@@ -7,7 +7,7 @@
RELEVANCE_THRESHOLD = 0.1


-async def rerank_and_chunk(
+def rerank_and_chunk(
    tool_results: List[Dict[str, Any]], model: BaseDeployment, **kwargs: Any
) -> List[Dict[str, Any]]:
    """
@@ -74,7 +74,7 @@ async def rerank_and_chunk(
        if not chunked_outputs:
            continue

-        res = await model.invoke_rerank(
+        res = model.invoke_rerank(
            query=query,
            documents=chunked_outputs,
            trace_id=trace_id,
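With `rerank_and_chunk` synchronous, its caller in `call_tools` (see custom.py below) drops the `await` and needs no running event loop. A hedged sketch of the new call shape; the stub deployment and the `tool_results` structure are illustrative guesses, not taken from this repo:

    from typing import Any, Dict, List

    from backend.chat.collate import rerank_and_chunk  # assumes src/ on PYTHONPATH

    class StubDeployment:
        # Stand-in for a BaseDeployment; only the hook rerank_and_chunk calls.
        def invoke_rerank(self, query: str, documents: List[str], **kwargs: Any) -> Dict[str, Any]:
            return {"results": [{"index": i, "relevance_score": 0.9} for i in range(len(documents))]}

    tool_results = [
        {"call": {"name": "search", "parameters": {"query": "tuna"}},
         "outputs": [{"text": "how to cook tuna"}]},
    ]

    reranked = rerank_and_chunk(tool_results, StubDeployment(), trace_id="t-1")  # direct call, no await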
125 changes: 60 additions & 65 deletions src/backend/chat/custom/custom.py
@@ -22,7 +22,7 @@
class CustomChat(BaseChat):
    """Custom chat flow not using integrations for models."""

-    async def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
+    def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
        """
        Chat flow for custom models.
@@ -54,33 +54,32 @@ async def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
        try:
            stream = self.call_chat(self.chat_request, deployment_model, **kwargs)

-            async with AsyncGeneratorContextManager(stream) as stream:
-                async for event in stream:
+            for event in stream:
+                send_log_message(
+                    logger,
+                    f"Stream event: {event}",
+                    level="info",
+                    conversation_id=kwargs.get("conversation_id"),
+                    user_id=kwargs.get("user_id"),
+                )
+                result = self.handle_event(event, chat_request)
+
+                if result:
+                    yield result
+
+                if event[
+                    "event_type"
+                ] == StreamEvent.STREAM_END and self.is_final_event(
+                    event, chat_request
+                ):
                    send_log_message(
                        logger,
-                        f"Stream event: {event}",
+                        f"Final event: {event}",
                        level="info",
                        conversation_id=kwargs.get("conversation_id"),
                        user_id=kwargs.get("user_id"),
                    )
-                    result = self.handle_event(event, chat_request)
-
-                    if result:
-                        yield result
-
-                    if event[
-                        "event_type"
-                    ] == StreamEvent.STREAM_END and self.is_final_event(
-                        event, chat_request
-                    ):
-                        send_log_message(
-                            logger,
-                            f"Final event: {event}",
-                            level="info",
-                            conversation_id=kwargs.get("conversation_id"),
-                            user_id=kwargs.get("user_id"),
-                        )
-                        break
+                    break
        except Exception as e:
            yield {
                "event_type": StreamEvent.STREAM_END,
@@ -138,7 +137,7 @@ def is_not_direct_answer(self, event: Dict[str, Any]) -> bool:
and "tool_calls" in event
)

async def call_chat(self, chat_request, deployment_model, **kwargs: Any):
def call_chat(self, chat_request, deployment_model, **kwargs: Any):
trace_id = kwargs.get("trace_id", "")
user_id = kwargs.get("user_id", "")
agent_id = kwargs.get("agent_id", "")
@@ -186,20 +185,17 @@ async def call_chat(self, chat_request, deployment_model, **kwargs: Any):

        # Invoke chat stream
        has_tool_calls = False
-        async with AsyncGeneratorContextManager(
-            deployment_model.invoke_chat_stream(
-                chat_request, trace_id=trace_id, user_id=user_id, agent_id=agent_id
-            )
-        ) as chat_stream:
-            async for event in chat_stream:
-                if event["event_type"] == StreamEvent.STREAM_END:
-                    chat_request.chat_history = event["response"].get(
-                        "chat_history", []
-                    )
-                elif event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
-                    has_tool_calls = True
+        for event in deployment_model.invoke_chat_stream(
+            chat_request, trace_id=trace_id, user_id=user_id, agent_id=agent_id
+        ):
+            if event["event_type"] == StreamEvent.STREAM_END:
+                chat_request.chat_history = event["response"].get(
+                    "chat_history", []
+                )
+            elif event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
+                has_tool_calls = True

-                yield event
+            yield event

        send_log_message(
            logger,
@@ -212,7 +208,7 @@ async def call_chat(self, chat_request, deployment_model, **kwargs: Any):
        # Check for new tool calls in the chat history
        if has_tool_calls:
            # Handle tool calls
-            tool_results = await self.call_tools(
+            tool_results = self.call_tools(
                chat_request.chat_history, deployment_model, **kwargs
            )

@@ -234,7 +230,7 @@ def update_chat_history_with_tool_results(

        chat_request.chat_history.extend(tool_results)

-    async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
+    def call_tools(self, chat_history, deployment_model, **kwargs: Any):
        tool_results = []
        if "tool_calls" not in chat_history[-1]:
            return tool_results
@@ -262,7 +258,7 @@ async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
            if not tool:
                continue

-            outputs = await tool.implementation().call(
+            outputs = tool.implementation().call(
                parameters=tool_call.get("parameters"),
                session=kwargs.get("session"),
                model_deployment=deployment_model,
@@ -277,7 +273,7 @@ async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
            for output in outputs:
                tool_results.append({"call": tool_call, "outputs": [output]})

-        tool_results = await rerank_and_chunk(tool_results, deployment_model, **kwargs)
+        tool_results = rerank_and_chunk(tool_results, deployment_model, **kwargs)
        send_log_message(
            logger,
            f"Tool results: {tool_results}",
@@ -294,33 +290,32 @@ async def handle_tool_calls_stream(self, tool_results_stream):
        is_direct_answer = True

        chat_history = []
-        async with AsyncGeneratorContextManager(stream) as chat_stream:
-            async for event in chat_stream:
-                if event["event_type"] == StreamEvent.STREAM_END:
-                    stream_chat_history = []
-                    if "response" in event:
-                        stream_chat_history = event["response"].get("chat_history", [])
-                    elif "chat_history" in event:
-                        stream_chat_history = event["chat_history"]
-
-                    for message in stream_chat_history:
-                        if not isinstance(message, dict):
-                            message = to_dict(message)
-
-                        chat_history.append(
-                            ChatMessage(
-                                role=message.get("role"),
-                                message=message.get("message", ""),
-                                tool_results=message.get("tool_results", None),
-                                tool_calls=message.get("tool_calls", None),
-                            )
+        for event in stream:
+            if event["event_type"] == StreamEvent.STREAM_END:
+                stream_chat_history = []
+                if "response" in event:
+                    stream_chat_history = event["response"].get("chat_history", [])
+                elif "chat_history" in event:
+                    stream_chat_history = event["chat_history"]
+
+                for message in stream_chat_history:
+                    if not isinstance(message, dict):
+                        message = to_dict(message)
+
+                    chat_history.append(
+                        ChatMessage(
+                            role=message.get("role"),
+                            message=message.get("message", ""),
+                            tool_results=message.get("tool_results", None),
+                            tool_calls=message.get("tool_calls", None),
+                        )
                        )
+                    )

-                elif (
-                    event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION
-                    and "tool_calls" in event
-                ):
-                    is_direct_answer = False
+            elif (
+                event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION
+                and "tool_calls" in event
+            ):
+                is_direct_answer = False

        return is_direct_answer, chat_history, stream_copy

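End to end, the custom chat flow is now driven by ordinary iteration. A consumption sketch, assuming this repo's import paths; the request fields and kwargs shown are illustrative:

    from backend.chat.custom.custom import CustomChat
    from backend.schemas.cohere_chat import CohereChatRequest

    chat = CustomChat()
    request = CohereChatRequest(message="What is retrieval-augmented generation?")

    # chat() is a plain generator now: a regular for loop replaces `async for`
    # plus the AsyncGeneratorContextManager wrapper.
    for event in chat.chat(request, user_id="u-1", conversation_id="c-1"):
        print(event)  # whatever handle_event produced, ending at STREAM_END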
4 changes: 2 additions & 2 deletions src/backend/main.py
@@ -73,8 +73,8 @@ def create_app():
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(LoggingMiddleware)
app.add_middleware(MetricsMiddleware)
# app.add_middleware(LoggingMiddleware)
# app.add_middleware(MetricsMiddleware)

return app

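The two middlewares are commented out rather than deleted. If the intent is to keep them switchable while the sync path is under test, one hypothetical pattern (not part of this commit) is an environment toggle inside `create_app`:

    import os

    # Hypothetical flag name; LoggingMiddleware and MetricsMiddleware as
    # already imported in main.py.
    if os.environ.get("ENABLE_MIDDLEWARE", "").lower() == "true":
        app.add_middleware(LoggingMiddleware)
        app.add_middleware(MetricsMiddleware)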
18 changes: 9 additions & 9 deletions src/backend/model_deployments/cohere_platform.py
@@ -73,17 +73,15 @@ def list_models(cls) -> List[str]:
    def is_available(cls) -> bool:
        return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS])

-    @collect_metrics_chat
-    async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
+    # @collect_metrics_chat
+    def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
        response = self.client.chat(
            **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
        )
        yield to_dict(response)

-    @collect_metrics_chat_stream
-    async def invoke_chat_stream(
-        self, chat_request: CohereChatRequest, **kwargs: Any
-    ) -> AsyncGenerator[Any, Any]:
+    # @collect_metrics_chat_stream
+    def invoke_chat_stream(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
        stream = self.client.chat_stream(
            **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
        )
@@ -98,10 +96,12 @@ async def invoke_chat_stream(
            )
            yield to_dict(event)

-    @collect_metrics_rerank
-    async def invoke_rerank(
+    # @collect_metrics_rerank
+    def invoke_rerank(
        self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
    ) -> Any:
-        return self.client.rerank(
+        response = self.client.rerank(
            query=query, documents=documents, model=DEFAULT_RERANK_MODEL
        )
+
+        return to_dict(response)
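Note the behavior change in `invoke_rerank`: it now returns `to_dict(response)` instead of the raw client response, matching what the sync `rerank_and_chunk` consumes. A hedged usage sketch, assuming `deployment` is a configured instance of this file's deployment class and `chat_request` a `CohereChatRequest`:

    # Streaming is plain iteration over dicts now.
    for event in deployment.invoke_chat_stream(chat_request):
        print(event["event_type"])

    reranked = deployment.invoke_rerank(
        query="tuna recipes",
        documents=[{"text": "how to cook tuna"}, {"text": "bike repair"}],
    )
    # `reranked` is a plain dict via to_dict, not a client response object.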
94 changes: 44 additions & 50 deletions src/backend/services/chat.py
@@ -518,30 +518,27 @@ async def generate_chat_response(
    )

    non_streamed_chat_response = None
-    async with AsyncGeneratorContextManager(stream) as chat_stream:
-        async for event in chat_stream:
-            event = json.loads(event)
-            if event["event"] == StreamEvent.STREAM_END:
-                data = event["data"]
-                response_id = response_message.id if response_message else None
-                generation_id = (
-                    response_message.generation_id if response_message else None
-                )
-
-                non_streamed_chat_response = NonStreamedChatResponse(
-                    text=data.get("text", ""),
-                    response_id=response_id,
-                    generation_id=generation_id,
-                    chat_history=data.get("chat_history", []),
-                    finish_reason=data.get("finish_reason", ""),
-                    citations=data.get("citations", []),
-                    search_queries=data.get("search_queries", []),
-                    documents=data.get("documents", []),
-                    search_results=data.get("search_results", []),
-                    event_type=StreamEvent.NON_STREAMED_CHAT_RESPONSE,
-                    conversation_id=conversation_id,
-                    tool_calls=data.get("tool_calls", []),
-                )
+    for event in stream:
+        event = json.loads(event)
+        if event["event"] == StreamEvent.STREAM_END:
+            data = event["data"]
+            response_id = response_message.id if response_message else None
+            generation_id = response_message.generation_id if response_message else None
+
+            non_streamed_chat_response = NonStreamedChatResponse(
+                text=data.get("text", ""),
+                response_id=response_id,
+                generation_id=generation_id,
+                chat_history=data.get("chat_history", []),
+                finish_reason=data.get("finish_reason", ""),
+                citations=data.get("citations", []),
+                search_queries=data.get("search_queries", []),
+                documents=data.get("documents", []),
+                search_results=data.get("search_results", []),
+                event_type=StreamEvent.NON_STREAMED_CHAT_RESPONSE,
+                conversation_id=conversation_id,
+                tool_calls=data.get("tool_calls", []),
+            )

    return non_streamed_chat_response

@@ -586,35 +583,32 @@ async def generate_chat_stream(
    document_ids_to_document = {}

    stream_event = None
-    async with AsyncGeneratorContextManager(
-        model_deployment_stream
-    ) as model_deployment_stream_tmp:
-        async for event in model_deployment_stream_tmp:
-            (
-                stream_event,
-                stream_end_data,
-                response_message,
-                document_ids_to_document,
-            ) = handle_stream_event(
-                event,
-                conversation_id,
-                stream_end_data,
-                response_message,
-                document_ids_to_document,
-                session=session,
-                should_store=should_store,
-                user_id=user_id,
-                next_message_position=kwargs.get("next_message_position", 0),
-            )
+    for event in model_deployment_stream:
+        (
+            stream_event,
+            stream_end_data,
+            response_message,
+            document_ids_to_document,
+        ) = handle_stream_event(
+            event,
+            conversation_id,
+            stream_end_data,
+            response_message,
+            document_ids_to_document,
+            session=session,
+            should_store=should_store,
+            user_id=user_id,
+            next_message_position=kwargs.get("next_message_position", 0),
+        )

-            yield json.dumps(
-                jsonable_encoder(
-                    ChatResponseEvent(
-                        event=stream_event.event_type.value,
-                        data=stream_event,
-                    )
+        yield json.dumps(
+            jsonable_encoder(
+                ChatResponseEvent(
+                    event=stream_event.event_type.value,
+                    data=stream_event,
                )
            )
+        )

    if should_store:
        update_conversation_after_turn(
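The envelope yielded by `generate_chat_stream` is unchanged: each event is serialized as a `ChatResponseEvent` carrying `event` and `data` fields. A minimal stand-in showing the JSON shape (a dataclass here replaces the repo's schema class):

    import json
    from dataclasses import asdict, dataclass

    @dataclass
    class ChatResponseEvent:  # stand-in for the schema class used above
        event: str
        data: dict

    payload = json.dumps(asdict(ChatResponseEvent(event="text-generation", data={"text": "Hi"})))
    print(payload)  # {"event": "text-generation", "data": {"text": "Hi"}}

Starlette's StreamingResponse also accepts plain (non-async) iterators, so a synchronous generator of such strings can still back the streaming endpoint.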
2 changes: 1 addition & 1 deletion src/backend/tools/python_interpreter.py
@@ -29,7 +29,7 @@ class PythonInterpreter(BaseTool):
    def is_available(cls) -> bool:
        return cls.INTERPRETER_URL is not None

-    async def call(self, parameters: dict, **kwargs: Any):
+    def call(self, parameters: dict, **kwargs: Any):
        if not self.INTERPRETER_URL:
            raise Exception("Python Interpreter tool called while URL not set")

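With `call` synchronous, invoking the tool is a direct method call. A hedged sketch; the env var name and the `parameters` shape are assumptions, not taken from this diff:

    import os

    os.environ.setdefault("PYTHON_INTERPRETER_URL", "http://localhost:8080")  # assumed name

    from backend.tools.python_interpreter import PythonInterpreter

    tool = PythonInterpreter()
    outputs = tool.call(parameters={"code": "print(1 + 1)"})  # direct call, no await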