diff --git a/hive_agent_client/chat/__init__.py b/hive_agent_client/chat/__init__.py index 379dce8..af4098c 100644 --- a/hive_agent_client/chat/__init__.py +++ b/hive_agent_client/chat/__init__.py @@ -2,5 +2,4 @@ send_chat_message, get_chat_history, get_all_chats, - send_chat_media ) diff --git a/hive_agent_client/chat/chat.py b/hive_agent_client/chat/chat.py index 57e67f5..5a624b8 100644 --- a/hive_agent_client/chat/chat.py +++ b/hive_agent_client/chat/chat.py @@ -1,9 +1,13 @@ import httpx +import io +import json import logging +import mimetypes import os import sys -from typing import List, Dict +from typing import List, Dict, Union +from fastapi import UploadFile def get_log_level(): HIVE_AGENT_LOG_LEVEL = os.getenv("HIVE_AGENT_LOG_LEVEL", "INFO").upper() @@ -18,11 +22,12 @@ def get_log_level(): async def send_chat_message( - http_client: httpx.AsyncClient, - base_url: str, - user_id: str, - session_id: str, - content: str, + http_client: httpx.AsyncClient, + base_url: str, + user_id: str, + session_id: str, + content: str, + files: List[Union[UploadFile, str]] = None # Can be either UploadFile or file path ) -> str: """ Sends a chat message to the Hive Agent API and returns the response. @@ -32,6 +37,7 @@ async def send_chat_message( :param user_id: The user ID. :param session_id: The session ID. :param content: The content of the message to be sent. + :param files: List of UploadFile objects or file paths to be uploaded. :return: The response text from the API. :raises ValueError: If the content is empty. :raises httpx.HTTPStatusError: If the request fails due to a network error or returns a 4xx/5xx response. @@ -42,15 +48,46 @@ async def send_chat_message( endpoint = "/chat" url = f"{base_url}{endpoint}" - payload = { + + chat_data = json.dumps({ + "messages": [ + { + "role": "user", + "content": content + } + ] + }) + + form_data = { "user_id": user_id, "session_id": session_id, - "chat_data": {"messages": [{"role": "user", "content": content}]}, + "chat_data": chat_data, # Send as a JSON string } + files_to_send = [] + + if files: + for file in files: + if isinstance(file, UploadFile): + # Handle UploadFile (received from an HTTP request) + files_to_send.append( + ("files", (file.filename, file.file, file.content_type)) + ) + elif isinstance(file, str): + # Handle file paths + file_name = os.path.basename(file) + content_type, _ = mimetypes.guess_type(file) + files_to_send.append( + ("files", (file_name, open(file, "rb"), content_type)) + ) + try: - logging.debug(f"Sending chat message to {url}: {content}") - response = await http_client.post(url, json=payload) + logging.debug(f"Sending chat message to {url}") + response = await http_client.post( + url, + data=form_data, + files=files_to_send # Attach files if any + ) response.raise_for_status() logger.debug(f"Response from chat message {content}: {response.text}") return response.text @@ -73,6 +110,11 @@ async def send_chat_message( raise Exception( f"An unexpected error occurred when sending message to the chat API: {e}" ) + finally: + # Close any files opened from file paths + for file in files_to_send: + if isinstance(file[1][1], io.IOBase): # Check if it's a file object + file[1][1].close() async def get_chat_history( @@ -171,70 +213,3 @@ async def get_all_chats( raise Exception( f"An unexpected error occurred when fetching all chats from the chat API: {e}" ) - - -async def send_chat_media( - http_client: httpx.AsyncClient, - base_url: str, - user_id: str, - session_id: str, - chat_data: str, - files: List[str], -) -> str: - """ - Sends a chat message with associated media files to the Hive Agent API and returns the response. - - :param http_client: An instance of httpx.AsyncClient to make HTTP requests. - :param base_url: The base URL of the Hive Agent API. - :param user_id: The user ID. - :param session_id: The session ID. - :param chat_data: The chat data in JSON format as a string. - :param files: A list of file paths to be uploaded. - :return: The response text from the API. - :raises ValueError: If the chat_data or files list is empty. - :raises httpx.HTTPStatusError: If the request fails due to a network error or returns a 4xx/5xx response. - :raises Exception: For other types of errors. - """ - if not chat_data.strip(): - raise ValueError("Chat data must not be empty") - if not files: - raise ValueError("Files list must not be empty") - - endpoint = "/chat_media" - url = f"{base_url}{endpoint}" - - files_data = [('files', open(file_path, 'rb')) for file_path in files] - data = { - 'user_id': user_id, - 'session_id': session_id, - 'chat_data': chat_data, - } - - try: - logging.debug(f"Sending chat media to {url} with files: {files}") - response = await http_client.post(url, data=data, files=files_data) - response.raise_for_status() - logger.debug(f"Response from chat media: {response.text}") - return response.text - except httpx.HTTPStatusError as e: - logging.error( - f"HTTP error occurred when sending chat media to {url}: {e.response.status_code} - {e.response.text}" - ) - raise Exception( - f"HTTP error occurred when sending chat media to the chat API: {e.response.status_code} - {e.response.text}" - ) - except httpx.RequestError as e: - logging.error(f"Request error occurred when sending chat media to {url}: {e}") - raise Exception( - f"Request error occurred when sending chat media to the chat API: {e}" - ) - except Exception as e: - logging.error( - f"An unexpected error occurred when sending chat media to {url}: {e}" - ) - raise Exception( - f"An unexpected error occurred when sending chat media to the chat API: {e}" - ) - finally: - for _, file_handle in files_data: - file_handle.close() diff --git a/hive_agent_client/client.py b/hive_agent_client/client.py index a764152..24e5770 100644 --- a/hive_agent_client/client.py +++ b/hive_agent_client/client.py @@ -1,13 +1,14 @@ import httpx import logging -from typing import Dict, List, Any +from typing import Dict, List, Any, Union + +from fastapi import UploadFile from hive_agent_client.chat import ( send_chat_message, get_chat_history, get_all_chats, - send_chat_media ) from hive_agent_client.database import ( create_table, @@ -41,24 +42,31 @@ def __init__(self, base_url: str, version: str = "v1", timeout: float = 30.0): self.base_url = f"{base_url}/{version}" self.http_client = httpx.AsyncClient(timeout=timeout) - async def chat(self, user_id: str, session_id: str, content: str) -> str: + async def chat(self, user_id: str, session_id: str, content: str, files: List[Union[UploadFile, str]] = None) -> str: """ - Send a message to the chat endpoint. + Send a message to the chat endpoint with optional file attachments. :param user_id: The user ID. :param session_id: The session ID. :param content: The content of the message to send. + :param files: Optional list of file paths or UploadFile objects to send. :return: The response from the chat API as a string. """ try: logger.debug(f"Sending message to chat endpoint: {content}") return await send_chat_message( - self.http_client, self.base_url, user_id, session_id, content + self.http_client, self.base_url, user_id, session_id, content, files ) except Exception as e: logger.error(f"Failed to send chat message - {content}: {e}") raise Exception(f"Failed to send chat message: {e}") + async def close(self): + """ + Close the HTTP client session. + """ + await self.http_client.aclose() + async def get_chat_history(self, user_id: str, session_id: str) -> List[Dict]: """ Retrieve the chat history for a specified user and session. @@ -91,31 +99,6 @@ async def get_all_chats(self, user_id: str) -> Dict[str, List[Dict]]: logger.error(f"Failed to get all chats for user {user_id}: {e}") raise Exception(f"Failed to get all chats: {e}") - async def chat_media( - self, - user_id: str, - session_id: str, - chat_data: str, - files: List[str], - ) -> str: - """ - Send a chat message with associated media files to the chat_media endpoint. - - :param user_id: The user ID. - :param session_id: The session ID. - :param chat_data: The chat data in JSON format as a string. - :param files: A list of file paths to be uploaded. - :return: The response from the chat_media API as a string. - """ - try: - logger.debug(f"Sending chat media to chat_media endpoint with files: {files}") - return await send_chat_media( - self.http_client, self.base_url, user_id, session_id, chat_data, files - ) - except Exception as e: - logger.error(f"Failed to send chat media - files: {files}, error: {e}") - raise Exception(f"Failed to send chat media: {e}") - async def create_table(self, table_name: str, columns: dict) -> Dict: """ Create a new table in the database. diff --git a/pyproject.toml b/pyproject.toml index 95998ec..6a01b18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,8 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.8" httpx = "^0.27.0" +fastapi = "0.115.0" +pillow = "10.4.0" [tool.poetry.dev-dependencies] python = "^3.8" diff --git a/requirements.txt b/requirements.txt index 301b0a3..461591c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ httpx==0.27.0 +fastapi==0.115.0 +pillow==10.4.0 diff --git a/setup.py b/setup.py index 8d60e5f..0e387c5 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name="hive-agent-client", version="0.0.1", packages=find_packages(include=["hive_agent_client", "hive_agent_client.*"]), - install_requires=["httpx==0.27.0"], + install_requires=["httpx==0.27.0", "fastapi==0.115.0", "pillow==10.4.0"], description="A client library for sending messages to a Hive Agent", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/tests/chat/test_chat.py b/tests/chat/test_chat.py index 8c80278..7427349 100644 --- a/tests/chat/test_chat.py +++ b/tests/chat/test_chat.py @@ -1,5 +1,6 @@ import httpx import os +import json import pytest from PIL import Image @@ -9,10 +10,32 @@ send_chat_message, get_chat_history, get_all_chats, - send_chat_media ) +@pytest.fixture +def temp_image_files(): + file_paths = ["test1.png", "test2.png"] + for file_path in file_paths: + # Create a simple image file + image = Image.new('RGB', (60, 30), color=(73, 109, 137)) + image.save(file_path) + yield file_paths + for file_path in file_paths: + os.remove(file_path) + + +@pytest.fixture +def temp_files(): + file_paths = ["test1.txt", "test2.txt"] + for file_path in file_paths: + with open(file_path, "w") as f: + f.write("test content") + yield file_paths + for file_path in file_paths: + os.remove(file_path) + + @pytest.mark.asyncio async def test_send_chat_message_success(): mock_client = AsyncMock(spec=httpx.AsyncClient) @@ -33,14 +56,43 @@ async def test_send_chat_message_success(): assert result == "Hello, world!" mock_client.post.assert_called_once_with( "http://example.com/api/v1/chat", - json={ + data={ "user_id": user_id, "session_id": session_id, - "chat_data": {"messages": [{"role": "user", "content": content}]}, + "chat_data": json.dumps({ + "messages": [{"role": "user", "content": content}] + }), }, + files=[] ) +@pytest.mark.asyncio +async def test_send_chat_message_with_files(temp_image_files): + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.text = "Files uploaded!" + mock_client.post.return_value = mock_response + + base_url = "http://example.com/api/v1" + user_id = "user123" + session_id = "session123" + content = "Check out these images" + + # Test with local file paths + result = await send_chat_message( + mock_client, base_url, user_id, session_id, content, files=temp_image_files + ) + + assert result == "Files uploaded!" + mock_client.post.assert_called_once() + + # Ensure the files are being uploaded correctly + assert mock_client.post.call_args[1]['files'][0][0] == 'files' + assert mock_client.post.call_args[1]['files'][1][0] == 'files' + + @pytest.mark.asyncio async def test_send_chat_message_empty_content(): mock_client = AsyncMock(spec=httpx.AsyncClient) @@ -70,8 +122,8 @@ async def test_send_chat_message_http_error(): content = "Hello, how are you?" with pytest.raises( - Exception, - match="HTTP error occurred when sending message to the chat API: 400 - Bad request", + Exception, + match="HTTP error occurred when sending message to the chat API: 400 - Bad request", ): await send_chat_message(mock_client, base_url, user_id, session_id, content) @@ -129,8 +181,8 @@ async def test_get_chat_history_failure(): session_id = "session123" with pytest.raises( - Exception, - match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request", + Exception, + match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request", ): await get_chat_history(mock_client, base_url, user_id, session_id) @@ -198,113 +250,19 @@ async def test_get_all_chats_failure(): user_id = "user123" with pytest.raises( - Exception, - match="HTTP error occurred when fetching all chats from the chat API: 400 - Bad request", + Exception, + match="HTTP error occurred when fetching all chats from the chat API: 400 - Bad request", ): await get_all_chats(mock_client, base_url, user_id) + @pytest.fixture def temp_image_files(): file_paths = ["test1.png", "test2.png"] for file_path in file_paths: # Create a simple image file - image = Image.new('RGB', (60, 30), color = (73, 109, 137)) + image = Image.new('RGB', (60, 30), color=(73, 109, 137)) image.save(file_path) yield file_paths for file_path in file_paths: os.remove(file_path) - - -@pytest.mark.asyncio -async def test_send_chat_media_success(temp_image_files): - mock_client = AsyncMock(spec=httpx.AsyncClient) - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - mock_response.text = "Media uploaded and response generated" - mock_client.post.return_value = mock_response - - base_url = "http://example.com/api/v1" - user_id = "user123" - session_id = "session123" - chat_data = '{"messages": [{"role": "user", "content": "Here is a file"}]}' - - result = await send_chat_media( - mock_client, base_url, user_id, session_id, chat_data, temp_image_files - ) - - assert result == "Media uploaded and response generated" - mock_client.post.assert_called_once() - assert mock_client.post.call_args[1]["data"] == { - "user_id": user_id, - "session_id": session_id, - "chat_data": chat_data, - } - # Verifying that the files were sent correctly in the POST request - assert len(mock_client.post.call_args[1]["files"]) == len(temp_image_files) - - -@pytest.mark.asyncio -async def test_send_chat_media_empty_chat_data(temp_image_files): - mock_client = AsyncMock(spec=httpx.AsyncClient) - base_url = "http://example.com/api/v1" - user_id = "user123" - session_id = "session123" - chat_data = "" - - with pytest.raises(ValueError, match="Chat data must not be empty"): - await send_chat_media(mock_client, base_url, user_id, session_id, chat_data, temp_image_files) - - -@pytest.mark.asyncio -async def test_send_chat_media_empty_files(): - mock_client = AsyncMock(spec=httpx.AsyncClient) - base_url = "http://example.com/api/v1" - user_id = "user123" - session_id = "session123" - chat_data = '{"messages": [{"role": "user", "content": "Here is a file"}]}' - files = [] - - with pytest.raises(ValueError, match="Files list must not be empty"): - await send_chat_media(mock_client, base_url, user_id, session_id, chat_data, files) - - -@pytest.mark.asyncio -async def test_send_chat_media_http_error(temp_image_files): - mock_client = AsyncMock(spec=httpx.AsyncClient) - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 400 - mock_response.text = "Bad request" - mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( - message="Bad request", request=mock_response.request, response=mock_response - ) - mock_client.post.return_value = mock_response - - base_url = "http://example.com/api/v1" - user_id = "user123" - session_id = "session123" - chat_data = '{"messages": [{"role": "user", "content": "Here is a file"}]}' - - with pytest.raises( - Exception, - match="HTTP error occurred when sending chat media to the chat API: 400 - Bad request", - ): - await send_chat_media(mock_client, base_url, user_id, session_id, chat_data, temp_image_files) - - -@pytest.mark.asyncio -async def test_send_chat_media_request_error(temp_image_files): - mock_client = AsyncMock(spec=httpx.AsyncClient) - base_url = "http://example.com/api/v1" - user_id = "user123" - session_id = "session123" - chat_data = '{"messages": [{"role": "user", "content": "Here is a file"}]}' - - mock_client.post.side_effect = httpx.RequestError( - "Request error occurred" - ) - - with pytest.raises( - Exception, - match="Request error occurred when sending chat media to the chat API: Request error occurred", - ): - await send_chat_media(mock_client, base_url, user_id, session_id, chat_data, temp_image_files) diff --git a/tests/test_client.py b/tests/test_client.py index 8995e09..e191ad3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -58,6 +58,29 @@ async def test_chat_success(): assert response == expected_response +@pytest.mark.asyncio +async def test_chat_success_with_files(temp_image_files): + user_id = "user123" + session_id = "session123" + content = "What is in these images?" + expected_response = "Response from chat with files" + + # Mocking the response for the chat request with files + with respx.mock() as mock: + mock.post(f"{base_url}/v1/chat").mock( + return_value=httpx.Response(200, text=expected_response) + ) + client = HiveAgentClient(base_url, version) + + response = await client.chat(user_id, session_id, content, files=temp_image_files) + assert response == expected_response + + # Ensure the files are part of the request + request_content = mock.calls[0].request.content + assert b"test1.png" in request_content + assert b"test2.png" in request_content + + @pytest.mark.asyncio async def test_chat_failure(): user_id = "user123" @@ -73,6 +96,22 @@ async def test_chat_failure(): assert "Failed to send chat message" in str(excinfo.value) +@pytest.mark.asyncio +async def test_chat_failure_with_files(temp_image_files): + user_id = "user123" + session_id = "session123" + content = "What is in these images?" + + # Mocking a failure response for the chat request with files + with respx.mock() as mock: + mock.post(f"{base_url}/v1/chat").mock(return_value=httpx.Response(400)) + client = HiveAgentClient(base_url, version) + + with pytest.raises(Exception) as excinfo: + await client.chat(user_id, session_id, content, files=temp_image_files) + assert "Failed to send chat message" in str(excinfo.value) + + @pytest.mark.asyncio async def test_get_chat_history_success(): user_id = "user123" @@ -181,38 +220,6 @@ async def test_get_all_chats_failure(): ) -@pytest.mark.asyncio -async def test_chat_media_success(temp_image_files): - user_id = "user123" - session_id = "session123" - chat_data = '{"messages": [{"role": "user", "content": "Here is a file"}]}' - expected_response = "Media uploaded and response generated" - - with respx.mock() as mock: - mock.post(f"{base_url}/v1/chat_media").mock( - return_value=httpx.Response(200, text=expected_response) - ) - - client = HiveAgentClient(base_url, version) - response = await client.chat_media(user_id, session_id, chat_data, temp_image_files) - assert response == expected_response - - -@pytest.mark.asyncio -async def test_chat_media_failure(temp_image_files): - user_id = "user123" - session_id = "session123" - chat_data = '{"messages": [{"role": "user", "content": "Here is a file"}]}' - - with respx.mock() as mock: - mock.post(f"{base_url}/v1/chat_media").mock(return_value=httpx.Response(400)) - - client = HiveAgentClient(base_url, version) - with pytest.raises(Exception) as excinfo: - await client.chat_media(user_id, session_id, chat_data, temp_image_files) - assert "Failed to send chat media" in str(excinfo.value) - - @pytest.mark.asyncio async def test_create_table_success(): table_name = "test_table" diff --git a/tutorial.md b/tutorial.md index 877bff3..efd750c 100644 --- a/tutorial.md +++ b/tutorial.md @@ -29,29 +29,14 @@ client = HiveAgentClient(base_url, version) To send a chat message using the `chat` method: ```python -async def send_message(user_id, session_id, content): +async def send_message(user_id, session_id, content, files): try: - response = await client.chat(user_id=user_id, session_id=session_id, content=content) + response = await client.chat(user_id=user_id, session_id=session_id, content=content, files=files) print("Chat response:", response) except Exception as e: print("Error:", e) ``` -## Sending Chat Messages with Media - -To send a chat message along with media files using the chat_media method: - -```python -async def send_message_with_media(user_id, session_id, chat_data, files): - try: - response = await client.chat_media(user_id=user_id, session_id=session_id, chat_data=chat_data, files=files) - print("Chat media response:", response) - except Exception as e: - print("Error:", e) -``` -Note: The files parameter should be a list of file paths to the media files you want to upload. The chat_data should be a JSON string that includes the message content. - - ## Getting Chat History To fetch the chat history, you can use the `get_chat_history` method: @@ -226,15 +211,17 @@ Here is how you might use the client in an asynchronous context: import asyncio async def main(): - await send_message("user123", "session123", "Hello, world!") - await fetch_chat_history("user123", "session123") - await fetch_all_chats("user123") - await send_message_with_media( + await send_message( "user123", "session123", - '{"messages": [{"role": "user", "content": "Here is a file"}]}', - ["path/to/file1.png", "path/to/file2.png"] + "Hello, world!", + [ + "path/to/file1.png", + UploadFile(filename="file1.png", file=open("path/to/file1.png", "rb")) + ] ) + await fetch_chat_history("user123", "session123") + await fetch_all_chats("user123") await create_new_table("my_table", {"id": "Integer", "name": "String"}) await insert_new_data("my_table", {"name": "Test"})