Skip to content
This repository has been archived by the owner on Nov 3, 2024. It is now read-only.

[HN-299/HN-305] feat: fetch all chats for specific user #20

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hive_agent_client/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .chat import send_chat_message, get_chat_history
from .chat import send_chat_message, get_chat_history, get_all_chats
64 changes: 60 additions & 4 deletions hive_agent_client/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def get_log_level():


async def send_chat_message(
http_client: httpx.AsyncClient, base_url: str, user_id: str, session_id: str, content: str
http_client: httpx.AsyncClient,
base_url: str,
user_id: str,
session_id: str,
content: str,
) -> str:
"""
Sends a chat message to the Hive Agent API and returns the response.
Expand All @@ -41,7 +45,7 @@ async def send_chat_message(
payload = {
"user_id": user_id,
"session_id": session_id,
"chat_data": {"messages": [{"role": "user", "content": content}]}
"chat_data": {"messages": [{"role": "user", "content": content}]},
}

try:
Expand Down Expand Up @@ -94,7 +98,9 @@ async def get_chat_history(
response = await http_client.get(url, params=params)
response.raise_for_status()
chat_history = response.json()
logger.debug(f"Chat history for user {user_id} and session {session_id}: {chat_history}")
logger.debug(
f"Chat history for user {user_id} and session {session_id}: {chat_history}"
)
return chat_history
except httpx.HTTPStatusError as e:
logging.error(
Expand All @@ -104,7 +110,9 @@ async def get_chat_history(
f"HTTP error occurred when fetching chat history from the chat API: {e.response.status_code} - {e.response.text}"
)
except httpx.RequestError as e:
logging.error(f"Request error occurred when fetching chat history from {url}: {e}")
logging.error(
f"Request error occurred when fetching chat history from {url}: {e}"
)
raise Exception(
f"Request error occurred when fetching chat history from the chat API: {e}"
)
Expand All @@ -115,3 +123,51 @@ async def get_chat_history(
raise Exception(
f"An unexpected error occurred when fetching chat history from the chat API: {e}"
)


async def get_all_chats(
http_client: httpx.AsyncClient, base_url: str, user_id: str
) -> Dict[str, List[Dict]]:
"""
Retrieves all chat sessions for a specified user from the Hive Agent API.

:param http_client: An instance of httpx.AsyncClient to make HTTP requests.
:param base_url: The base URL of the Hive Agent API.
:param user_id: The user ID.
:return: All chat sessions as a dictionary with session IDs as keys and lists of messages as values.
:raises httpx.HTTPStatusError: If the request fails due to a network error or returns a 4xx/5xx response.
:raises Exception: For other types of errors.
"""

endpoint = "/all_chats"
url = f"{base_url}{endpoint}"
params = {"user_id": user_id}

try:
logging.debug(f"Fetching all chats from {url} with params: {params}")
response = await http_client.get(url, params=params)
response.raise_for_status()

all_chats = response.json()
logger.debug(f"All chats for user {user_id}: {all_chats}")

return all_chats
except httpx.HTTPStatusError as e:
logging.error(
f"HTTP error occurred when fetching all chats from {url}: {e.response.status_code} - {e.response.text}"
)
raise Exception(
f"HTTP error occurred when fetching all chats from the chat API: {e.response.status_code} - {e.response.text}"
)
except httpx.RequestError as e:
logging.error(f"Request error occurred when fetching all chats from {url}: {e}")
raise Exception(
f"Request error occurred when fetching all chats from the chat API: {e}"
)
except Exception as e:
logging.error(
f"An unexpected error occurred when fetching all chats from {url}: {e}"
)
raise Exception(
f"An unexpected error occurred when fetching all chats from the chat API: {e}"
)
28 changes: 24 additions & 4 deletions hive_agent_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from typing import Dict, List

from hive_agent_client.chat import send_chat_message, get_chat_history
from hive_agent_client.chat import send_chat_message, get_chat_history, get_all_chats
from hive_agent_client.database import (
create_table,
insert_data,
Expand Down Expand Up @@ -45,7 +45,9 @@ async def chat(self, user_id: str, session_id: str, content: str) -> str:
"""
try:
logger.debug(f"Sending message to chat endpoint: {content}")
return await send_chat_message(self.http_client, self.base_url, user_id, session_id, content)
return await send_chat_message(
self.http_client, self.base_url, user_id, session_id, content
)
except Exception as e:
logger.error(f"Failed to send chat message - {content}: {e}")
raise Exception(f"Failed to send chat message: {e}")
Expand All @@ -59,11 +61,29 @@ async def get_chat_history(self, user_id: str, session_id: str) -> List[Dict]:
:return: The chat history as a list of dictionaries.
"""
try:
return await get_chat_history(self.http_client, self.base_url, user_id, session_id)
return await get_chat_history(
self.http_client, self.base_url, user_id, session_id
)
except Exception as e:
logger.error(f"Failed to get chat history for user {user_id} and session {session_id}: {e}")
logger.error(
f"Failed to get chat history for user {user_id} and session {session_id}: {e}"
)
raise Exception(f"Failed to get chat history: {e}")

async def get_all_chats(self, user_id: str) -> Dict[str, List[Dict]]:
"""
Retrieve all chat sessions for a specified user.

:param user_id: The user ID.
:return: All chat sessions as a dictionary with session IDs as keys and lists of messages as values.
"""

try:
return await get_all_chats(self.http_client, self.base_url, user_id)
except Exception as e:
logger.error(f"Failed to get all chats for user {user_id}: {e}")
raise Exception(f"Failed to get all chats: {e}")

async def create_table(self, table_name: str, columns: dict) -> Dict:
"""
Create a new table in the database.
Expand Down
103 changes: 92 additions & 11 deletions tests/chat/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from unittest.mock import AsyncMock
from hive_agent_client.chat import send_chat_message, get_chat_history
from hive_agent_client.chat import send_chat_message, get_chat_history, get_all_chats


@pytest.mark.asyncio
Expand All @@ -18,15 +18,17 @@ async def test_send_chat_message_success():
session_id = "session123"
content = "Hello, how are you?"

result = await send_chat_message(mock_client, base_url, user_id, session_id, content)
result = await send_chat_message(
mock_client, base_url, user_id, session_id, content
)

assert result == "Hello, world!"
mock_client.post.assert_called_once_with(
"http://example.com/api/v1/chat",
json={
"user_id": user_id,
"session_id": session_id,
"chat_data": {"messages": [{"role": "user", "content": content}]}
"chat_data": {"messages": [{"role": "user", "content": content}]},
},
)

Expand Down Expand Up @@ -60,8 +62,8 @@ async def test_send_chat_message_http_error():
content = "Hello, how are you?"

with pytest.raises(
Exception,
match="HTTP error occurred when sending message to the chat API: 400 - Bad request",
Exception,
match="HTTP error occurred when sending message to the chat API: 400 - Bad request",
):
await send_chat_message(mock_client, base_url, user_id, session_id, content)

Expand All @@ -72,10 +74,20 @@ async def test_get_chat_history_success():
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
expected_history = [
{"user_id": "user123", "session_id": "session123", "message": "Hello", "role": "user",
"timestamp": "2023-01-01T00:00:00Z"},
{"user_id": "user123", "session_id": "session123", "message": "Hi there", "role": "assistant",
"timestamp": "2023-01-01T00:00:01Z"}
{
"user_id": "user123",
"session_id": "session123",
"message": "Hello",
"role": "user",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"user_id": "user123",
"session_id": "session123",
"message": "Hi there",
"role": "assistant",
"timestamp": "2023-01-01T00:00:01Z",
},
]
mock_response.json.return_value = expected_history
mock_client.get.return_value = mock_response
Expand Down Expand Up @@ -109,7 +121,76 @@ async def test_get_chat_history_failure():
session_id = "session123"

with pytest.raises(
Exception,
match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request",
Exception,
match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request",
):
await get_chat_history(mock_client, base_url, user_id, session_id)


@pytest.mark.asyncio
async def test_get_all_chats_success():
mock_client = AsyncMock(spec=httpx.AsyncClient)
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200

expected_all_chats = {
"session1": [
{
"message": "Hello in session1",
"role": "USER",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"message": "Response in session1",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:01Z",
},
],
"session2": [
{
"message": "Hello in session2",
"role": "USER",
"timestamp": "2023-01-01T00:00:02Z",
},
{
"message": "Response in session2",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:03Z",
},
],
}

mock_response.json.return_value = expected_all_chats
mock_client.get.return_value = mock_response

base_url = "http://example.com/api/v1"
user_id = "user123"

result = await get_all_chats(mock_client, base_url, user_id)
assert result == expected_all_chats

mock_client.get.assert_called_once_with(
f"http://example.com/api/v1/all_chats",
params={"user_id": user_id},
)


@pytest.mark.asyncio
async def test_get_all_chats_failure():
mock_client = AsyncMock(spec=httpx.AsyncClient)
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 400
mock_response.text = "Bad request"
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
message="Bad request", request=mock_response.request, response=mock_response
)
mock_client.get.return_value = mock_response

base_url = "http://example.com/api/v1"
user_id = "user123"

with pytest.raises(
Exception,
match="HTTP error occurred when fetching all chats from the chat API: 400 - Bad request",
):
await get_all_chats(mock_client, base_url, user_id)
81 changes: 78 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,20 @@ async def test_get_chat_history_success():
user_id = "user123"
session_id = "session123"
expected_history = [
{"user_id": user_id, "session_id": session_id, "message": "Hello", "role": "user", "timestamp": "2023-01-01T00:00:00Z"},
{"user_id": user_id, "session_id": session_id, "message": "Hi there", "role": "assistant", "timestamp": "2023-01-01T00:00:01Z"}
{
"user_id": user_id,
"session_id": session_id,
"message": "Hello",
"role": "user",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"user_id": user_id,
"session_id": session_id,
"message": "Hi there",
"role": "assistant",
"timestamp": "2023-01-01T00:00:01Z",
},
]

with respx.mock() as mock:
Expand All @@ -90,7 +102,70 @@ async def test_get_chat_history_failure():
client = HiveAgentClient(base_url, version)
with pytest.raises(Exception) as excinfo:
await client.get_chat_history(user_id, session_id)
assert "HTTP error occurred when fetching chat history from the chat API: 400" in str(excinfo.value)
assert (
"HTTP error occurred when fetching chat history from the chat API: 400"
in str(excinfo.value)
)


@pytest.mark.asyncio
async def test_get_all_chats_success():
user_id = "user123"

expected_all_chats = {
"session1": [
{
"message": "Hello in session1",
"role": "USER",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"message": "Response in session1",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:01Z",
},
],
"session2": [
{
"message": "Hello in session2",
"role": "USER",
"timestamp": "2023-01-01T00:00:02Z",
},
{
"message": "Response in session2",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:03Z",
},
],
}

with respx.mock() as mock:
mock.get(f"{base_url}/v1/all_chats").mock(
return_value=httpx.Response(200, json=expected_all_chats)
)

client = HiveAgentClient(base_url, version)

all_chats = await client.get_all_chats(user_id)
assert all_chats == expected_all_chats


@pytest.mark.asyncio
async def test_get_all_chats_failure():
user_id = "user123"

with respx.mock() as mock:
mock.get(f"{base_url}/v1/all_chats").mock(return_value=httpx.Response(400))

client = HiveAgentClient(base_url, version)

with pytest.raises(Exception) as excinfo:
await client.get_all_chats(user_id)

assert (
"HTTP error occurred when fetching all chats from the chat API: 400"
in str(excinfo.value)
)


@pytest.mark.asyncio
Expand Down
Loading
Loading