Skip to content

Commit

Permalink
Merge pull request #932 from lucagobbi/fix_message_endpoint
Browse files Browse the repository at this point in the history
Run stray in threadpool to allow tool execution on http message endpoint
  • Loading branch information
pieroit authored Oct 6, 2024
2 parents 78455c0 + a494d20 commit 6b251cb
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
17 changes: 11 additions & 6 deletions core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,17 +469,22 @@ async def __call__(self, message_dict):

return final_output

def run(self, user_message_json):
def run(self, user_message_json, return_message=False):
try:
cat_message = self.loop.run_until_complete(self.__call__(user_message_json))
# send message back to client
self.send_chat_message(cat_message)
if return_message:
# return the message for HTTP usage
return cat_message
else:
# send message back to client via WS
self.send_chat_message(cat_message)
except Exception as e:
# Log any unexpected errors
log.error(e)
traceback.print_exc()
# Send error as websocket message
self.send_error(e)
if return_message:
return {"error": str(e)}
else:
self.send_error(e)

def classify(
self, sentence: str, labels: List[str] | Dict[str, List[str]]
Expand Down
4 changes: 3 additions & 1 deletion core/cat/routes/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict
import tomli
from cat.auth.permissions import AuthPermission, AuthResource
Expand Down Expand Up @@ -27,5 +28,6 @@ async def message_with_cat(
stray=Depends(HTTPAuth(AuthResource.CONVERSATION, AuthPermission.WRITE)),
) -> Dict:
"""Get a response from the Cat"""
answer = await stray({"user_id": stray.user_id, **payload})
user_message_json = {"user_id": stray.user_id, **payload}
answer = await run_in_threadpool(stray.run, user_message_json, True)
return answer
2 changes: 1 addition & 1 deletion core/cat/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def receive_message(websocket: WebSocket, stray: StrayCat):
user_message["user_id"] = stray.user_id

# Run the `stray` object's method in a threadpool since it might be a CPU-bound operation.
await run_in_threadpool(stray.run, user_message)
await run_in_threadpool(stray.run, user_message, return_message=False)


@router.websocket("/ws")
Expand Down

0 comments on commit 6b251cb

Please sign in to comment.