From d33ad21ec279318e338d88dbdead67d27461202b Mon Sep 17 00:00:00 2001 From: Tomisin Jenrola Date: Tue, 2 Jul 2024 21:32:58 -0400 Subject: [PATCH] [HN-299/HN-305] feat: fetch all chats for specific user --- hive_agent_client/chat/__init__.py | 2 +- hive_agent_client/chat/chat.py | 64 ++++++++++++++++-- hive_agent_client/client.py | 28 ++++++-- tests/chat/test_chat.py | 103 ++++++++++++++++++++++++++--- tests/test_client.py | 81 ++++++++++++++++++++++- tutorial.md | 14 ++++ 6 files changed, 269 insertions(+), 23 deletions(-) diff --git a/hive_agent_client/chat/__init__.py b/hive_agent_client/chat/__init__.py index 92f99ca..b9e029f 100644 --- a/hive_agent_client/chat/__init__.py +++ b/hive_agent_client/chat/__init__.py @@ -1 +1 @@ -from .chat import send_chat_message, get_chat_history +from .chat import send_chat_message, get_chat_history, get_all_chats diff --git a/hive_agent_client/chat/chat.py b/hive_agent_client/chat/chat.py index 89d1b8e..d9d19db 100644 --- a/hive_agent_client/chat/chat.py +++ b/hive_agent_client/chat/chat.py @@ -18,7 +18,11 @@ def get_log_level(): async def send_chat_message( - http_client: httpx.AsyncClient, base_url: str, user_id: str, session_id: str, content: str + http_client: httpx.AsyncClient, + base_url: str, + user_id: str, + session_id: str, + content: str, ) -> str: """ Sends a chat message to the Hive Agent API and returns the response. @@ -41,7 +45,7 @@ async def send_chat_message( payload = { "user_id": user_id, "session_id": session_id, - "chat_data": {"messages": [{"role": "user", "content": content}]} + "chat_data": {"messages": [{"role": "user", "content": content}]}, } try: @@ -94,7 +98,9 @@ async def get_chat_history( response = await http_client.get(url, params=params) response.raise_for_status() chat_history = response.json() - logger.debug(f"Chat history for user {user_id} and session {session_id}: {chat_history}") + logger.debug( + f"Chat history for user {user_id} and session {session_id}: {chat_history}" + ) return chat_history except httpx.HTTPStatusError as e: logging.error( @@ -104,7 +110,9 @@ async def get_chat_history( f"HTTP error occurred when fetching chat history from the chat API: {e.response.status_code} - {e.response.text}" ) except httpx.RequestError as e: - logging.error(f"Request error occurred when fetching chat history from {url}: {e}") + logging.error( + f"Request error occurred when fetching chat history from {url}: {e}" + ) raise Exception( f"Request error occurred when fetching chat history from the chat API: {e}" ) @@ -115,3 +123,51 @@ async def get_chat_history( raise Exception( f"An unexpected error occurred when fetching chat history from the chat API: {e}" ) + + +async def get_all_chats( + http_client: httpx.AsyncClient, base_url: str, user_id: str +) -> Dict[str, List[Dict]]: + """ + Retrieves all chat sessions for a specified user from the Hive Agent API. + + :param http_client: An instance of httpx.AsyncClient to make HTTP requests. + :param base_url: The base URL of the Hive Agent API. + :param user_id: The user ID. + :return: All chat sessions as a dictionary with session IDs as keys and lists of messages as values. + :raises httpx.HTTPStatusError: If the request fails due to a network error or returns a 4xx/5xx response. + :raises Exception: For other types of errors. + """ + + endpoint = "/all_chats" + url = f"{base_url}{endpoint}" + params = {"user_id": user_id} + + try: + logging.debug(f"Fetching all chats from {url} with params: {params}") + response = await http_client.get(url, params=params) + response.raise_for_status() + + all_chats = response.json() + logger.debug(f"All chats for user {user_id}: {all_chats}") + + return all_chats + except httpx.HTTPStatusError as e: + logging.error( + f"HTTP error occurred when fetching all chats from {url}: {e.response.status_code} - {e.response.text}" + ) + raise Exception( + f"HTTP error occurred when fetching all chats from the chat API: {e.response.status_code} - {e.response.text}" + ) + except httpx.RequestError as e: + logging.error(f"Request error occurred when fetching all chats from {url}: {e}") + raise Exception( + f"Request error occurred when fetching all chats from the chat API: {e}" + ) + except Exception as e: + logging.error( + f"An unexpected error occurred when fetching all chats from {url}: {e}" + ) + raise Exception( + f"An unexpected error occurred when fetching all chats from the chat API: {e}" + ) diff --git a/hive_agent_client/client.py b/hive_agent_client/client.py index c1537bd..8474b18 100644 --- a/hive_agent_client/client.py +++ b/hive_agent_client/client.py @@ -2,7 +2,7 @@ import logging from typing import Dict, List -from hive_agent_client.chat import send_chat_message, get_chat_history +from hive_agent_client.chat import send_chat_message, get_chat_history, get_all_chats from hive_agent_client.database import ( create_table, insert_data, @@ -45,7 +45,9 @@ async def chat(self, user_id: str, session_id: str, content: str) -> str: """ try: logger.debug(f"Sending message to chat endpoint: {content}") - return await send_chat_message(self.http_client, self.base_url, user_id, session_id, content) + return await send_chat_message( + self.http_client, self.base_url, user_id, session_id, content + ) except Exception as e: logger.error(f"Failed to send chat message - {content}: {e}") raise Exception(f"Failed to send chat message: {e}") @@ -59,11 +61,29 @@ async def get_chat_history(self, user_id: str, session_id: str) -> List[Dict]: :return: The chat history as a list of dictionaries. """ try: - return await get_chat_history(self.http_client, self.base_url, user_id, session_id) + return await get_chat_history( + self.http_client, self.base_url, user_id, session_id + ) except Exception as e: - logger.error(f"Failed to get chat history for user {user_id} and session {session_id}: {e}") + logger.error( + f"Failed to get chat history for user {user_id} and session {session_id}: {e}" + ) raise Exception(f"Failed to get chat history: {e}") + async def get_all_chats(self, user_id: str) -> Dict[str, List[Dict]]: + """ + Retrieve all chat sessions for a specified user. + + :param user_id: The user ID. + :return: All chat sessions as a dictionary with session IDs as keys and lists of messages as values. + """ + + try: + return await get_all_chats(self.http_client, self.base_url, user_id) + except Exception as e: + logger.error(f"Failed to get all chats for user {user_id}: {e}") + raise Exception(f"Failed to get all chats: {e}") + async def create_table(self, table_name: str, columns: dict) -> Dict: """ Create a new table in the database. diff --git a/tests/chat/test_chat.py b/tests/chat/test_chat.py index a3524f1..141c7ce 100644 --- a/tests/chat/test_chat.py +++ b/tests/chat/test_chat.py @@ -2,7 +2,7 @@ import pytest from unittest.mock import AsyncMock -from hive_agent_client.chat import send_chat_message, get_chat_history +from hive_agent_client.chat import send_chat_message, get_chat_history, get_all_chats @pytest.mark.asyncio @@ -18,7 +18,9 @@ async def test_send_chat_message_success(): session_id = "session123" content = "Hello, how are you?" - result = await send_chat_message(mock_client, base_url, user_id, session_id, content) + result = await send_chat_message( + mock_client, base_url, user_id, session_id, content + ) assert result == "Hello, world!" mock_client.post.assert_called_once_with( @@ -26,7 +28,7 @@ async def test_send_chat_message_success(): json={ "user_id": user_id, "session_id": session_id, - "chat_data": {"messages": [{"role": "user", "content": content}]} + "chat_data": {"messages": [{"role": "user", "content": content}]}, }, ) @@ -60,8 +62,8 @@ async def test_send_chat_message_http_error(): content = "Hello, how are you?" with pytest.raises( - Exception, - match="HTTP error occurred when sending message to the chat API: 400 - Bad request", + Exception, + match="HTTP error occurred when sending message to the chat API: 400 - Bad request", ): await send_chat_message(mock_client, base_url, user_id, session_id, content) @@ -72,10 +74,20 @@ async def test_get_chat_history_success(): mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 expected_history = [ - {"user_id": "user123", "session_id": "session123", "message": "Hello", "role": "user", - "timestamp": "2023-01-01T00:00:00Z"}, - {"user_id": "user123", "session_id": "session123", "message": "Hi there", "role": "assistant", - "timestamp": "2023-01-01T00:00:01Z"} + { + "user_id": "user123", + "session_id": "session123", + "message": "Hello", + "role": "user", + "timestamp": "2023-01-01T00:00:00Z", + }, + { + "user_id": "user123", + "session_id": "session123", + "message": "Hi there", + "role": "assistant", + "timestamp": "2023-01-01T00:00:01Z", + }, ] mock_response.json.return_value = expected_history mock_client.get.return_value = mock_response @@ -109,7 +121,76 @@ async def test_get_chat_history_failure(): session_id = "session123" with pytest.raises( - Exception, - match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request", + Exception, + match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request", ): await get_chat_history(mock_client, base_url, user_id, session_id) + + +@pytest.mark.asyncio +async def test_get_all_chats_success(): + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + + expected_all_chats = { + "session1": [ + { + "message": "Hello in session1", + "role": "USER", + "timestamp": "2023-01-01T00:00:00Z", + }, + { + "message": "Response in session1", + "role": "ASSISTANT", + "timestamp": "2023-01-01T00:00:01Z", + }, + ], + "session2": [ + { + "message": "Hello in session2", + "role": "USER", + "timestamp": "2023-01-01T00:00:02Z", + }, + { + "message": "Response in session2", + "role": "ASSISTANT", + "timestamp": "2023-01-01T00:00:03Z", + }, + ], + } + + mock_response.json.return_value = expected_all_chats + mock_client.get.return_value = mock_response + + base_url = "http://example.com/api/v1" + user_id = "user123" + + result = await get_all_chats(mock_client, base_url, user_id) + assert result == expected_all_chats + + mock_client.get.assert_called_once_with( + f"http://example.com/api/v1/all_chats", + params={"user_id": user_id}, + ) + + +@pytest.mark.asyncio +async def test_get_all_chats_failure(): + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 400 + mock_response.text = "Bad request" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message="Bad request", request=mock_response.request, response=mock_response + ) + mock_client.get.return_value = mock_response + + base_url = "http://example.com/api/v1" + user_id = "user123" + + with pytest.raises( + Exception, + match="HTTP error occurred when fetching all chats from the chat API: 400 - Bad request", + ): + await get_all_chats(mock_client, base_url, user_id) diff --git a/tests/test_client.py b/tests/test_client.py index 083fddd..e149a7a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -65,8 +65,20 @@ async def test_get_chat_history_success(): user_id = "user123" session_id = "session123" expected_history = [ - {"user_id": user_id, "session_id": session_id, "message": "Hello", "role": "user", "timestamp": "2023-01-01T00:00:00Z"}, - {"user_id": user_id, "session_id": session_id, "message": "Hi there", "role": "assistant", "timestamp": "2023-01-01T00:00:01Z"} + { + "user_id": user_id, + "session_id": session_id, + "message": "Hello", + "role": "user", + "timestamp": "2023-01-01T00:00:00Z", + }, + { + "user_id": user_id, + "session_id": session_id, + "message": "Hi there", + "role": "assistant", + "timestamp": "2023-01-01T00:00:01Z", + }, ] with respx.mock() as mock: @@ -90,7 +102,70 @@ async def test_get_chat_history_failure(): client = HiveAgentClient(base_url, version) with pytest.raises(Exception) as excinfo: await client.get_chat_history(user_id, session_id) - assert "HTTP error occurred when fetching chat history from the chat API: 400" in str(excinfo.value) + assert ( + "HTTP error occurred when fetching chat history from the chat API: 400" + in str(excinfo.value) + ) + + +@pytest.mark.asyncio +async def test_get_all_chats_success(): + user_id = "user123" + + expected_all_chats = { + "session1": [ + { + "message": "Hello in session1", + "role": "USER", + "timestamp": "2023-01-01T00:00:00Z", + }, + { + "message": "Response in session1", + "role": "ASSISTANT", + "timestamp": "2023-01-01T00:00:01Z", + }, + ], + "session2": [ + { + "message": "Hello in session2", + "role": "USER", + "timestamp": "2023-01-01T00:00:02Z", + }, + { + "message": "Response in session2", + "role": "ASSISTANT", + "timestamp": "2023-01-01T00:00:03Z", + }, + ], + } + + with respx.mock() as mock: + mock.get(f"{base_url}/v1/all_chats").mock( + return_value=httpx.Response(200, json=expected_all_chats) + ) + + client = HiveAgentClient(base_url, version) + + all_chats = await client.get_all_chats(user_id) + assert all_chats == expected_all_chats + + +@pytest.mark.asyncio +async def test_get_all_chats_failure(): + user_id = "user123" + + with respx.mock() as mock: + mock.get(f"{base_url}/v1/all_chats").mock(return_value=httpx.Response(400)) + + client = HiveAgentClient(base_url, version) + + with pytest.raises(Exception) as excinfo: + await client.get_all_chats(user_id) + + assert ( + "HTTP error occurred when fetching all chats from the chat API: 400" + in str(excinfo.value) + ) @pytest.mark.asyncio diff --git a/tutorial.md b/tutorial.md index eb1b7b6..cae2743 100644 --- a/tutorial.md +++ b/tutorial.md @@ -50,6 +50,19 @@ async def fetch_chat_history(user_id, session_id): print("Error:", e) ``` +## Getting All Chats + +To fetch all chat sessions for a user, you can use the `get_all_chats` method: + +```python +async def fetch_all_chats(user_id): + try: + all_chats = await client.get_all_chats(user_id=user_id) + print("All chats:", all_chats) + except Exception as e: + print("Error:", e) +``` + ## Creating a Table Create a new table in the database: @@ -200,6 +213,7 @@ import asyncio async def main(): await send_message("user123", "session123", "Hello, world!") await fetch_chat_history("user123", "session123") + await fetch_all_chats("user123") await create_new_table("my_table", {"id": "Integer", "name": "String"}) await insert_new_data("my_table", {"name": "Test"})