diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py
index da61d54..21ffe82 100644
--- a/nano_graphrag/_llm.py
+++ b/nano_graphrag/_llm.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from openai import AsyncOpenAI,AsyncAzureOpenAI, APIConnectionError, RateLimitError
+from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
 
 from tenacity import (
     retry,
@@ -84,8 +84,6 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
     return np.array([dp.embedding for dp in response.data])
 
 
-
-
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -94,9 +92,7 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
 async def azure_openai_complete_if_cache(
     deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    azure_openai_client = AsyncAzureOpenAI(
-
-    )
+    azure_openai_client = AsyncAzureOpenAI()
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -109,17 +105,22 @@ async def azure_openai_complete_if_cache(
         if if_cache_return is not None:
             return if_cache_return["return"]
 
-    async with azure_openai_client as client:
-        response = await client.chat.completions.create(
-            model=deployment_name, messages=messages, **kwargs
-        )
+    response = await azure_openai_client.chat.completions.create(
+        model=deployment_name, messages=messages, **kwargs
+    )
 
     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": deployment_name}}
+            {
+                args_hash: {
+                    "return": response.choices[0].message.content,
+                    "model": deployment_name,
+                }
+            }
         )
     return response.choices[0].message.content
 
+
 async def azure_gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -130,6 +131,8 @@ async def azure_gpt_4o_complete(
         history_messages=history_messages,
         **kwargs,
     )
+
+
 async def azure_gpt_4o_mini_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -148,14 +151,13 @@ async def azure_gpt_4o_mini_complete(
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
 )
-async def azure_openai_embedding(texts: list[str], **kwargs) -> np.ndarray:
+async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
     azure_openai_client = AsyncAzureOpenAI(
         api_key=os.environ.get("API_KEY_EMB"),
         api_version=os.environ.get("API_VERSION_EMB"),
-        azure_endpoint=os.environ.get("AZURE_ENDPOINT_EMB")
+        azure_endpoint=os.environ.get("AZURE_ENDPOINT_EMB"),
+    )
+    response = await azure_openai_client.embeddings.create(
+        model="text-embedding-3-small", input=texts, encoding_format="float"
     )
-    async with azure_openai_client as client:
-        response = await client.embeddings.create(
-            model="text-embedding-3-small", input=texts, **kwargs
-        )
     return np.array([dp.embedding for dp in response.data])
diff --git a/tests/test_openai.py b/tests/test_openai.py
new file mode 100644
index 0000000..afc5eea
--- /dev/null
+++ b/tests/test_openai.py
@@ -0,0 +1,116 @@
+import pytest
+import numpy as np
+from unittest.mock import AsyncMock, Mock, patch
+from nano_graphrag import _llm
+
+
+@pytest.fixture
+def mock_openai_client():
+    with patch("nano_graphrag._llm.AsyncOpenAI") as mock_openai:
+        mock_client = AsyncMock()
+        mock_openai.return_value = mock_client
+        yield mock_client
+
+
+@pytest.fixture
+def mock_azure_openai_client():
+    with patch("nano_graphrag._llm.AsyncAzureOpenAI") as mock_openai:
+        mock_client = AsyncMock()
+        mock_openai.return_value = mock_client
+        yield mock_client
+
+
+@pytest.mark.asyncio
+async def test_openai_gpt4o(mock_openai_client):
+    mock_response = AsyncMock()
+    mock_response.choices = [Mock(message=Mock(content="1"))]
+    messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
+    mock_openai_client.chat.completions.create.return_value = mock_response
+
+    response = await _llm.gpt_4o_complete("2", system_prompt="3")
+
+    mock_openai_client.chat.completions.create.assert_awaited_once_with(
+        model="gpt-4o",
+        messages=messages,
+    )
+    assert response == "1"
+
+
+@pytest.mark.asyncio
+async def test_openai_gpt4o_mini(mock_openai_client):
+    mock_response = AsyncMock()
+    mock_response.choices = [Mock(message=Mock(content="1"))]
+    messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
+    mock_openai_client.chat.completions.create.return_value = mock_response
+
+    response = await _llm.gpt_4o_mini_complete("2", system_prompt="3")
+
+    mock_openai_client.chat.completions.create.assert_awaited_once_with(
+        model="gpt-4o-mini",
+        messages=messages,
+    )
+    assert response == "1"
+
+
+@pytest.mark.asyncio
+async def test_azure_openai_gpt4o(mock_azure_openai_client):
+    mock_response = AsyncMock()
+    mock_response.choices = [Mock(message=Mock(content="1"))]
+    messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
+    mock_azure_openai_client.chat.completions.create.return_value = mock_response
+
+    response = await _llm.azure_gpt_4o_complete("2", system_prompt="3")
+
+    mock_azure_openai_client.chat.completions.create.assert_awaited_once_with(
+        model="gpt-4o",
+        messages=messages,
+    )
+    assert response == "1"
+
+
+@pytest.mark.asyncio
+async def test_azure_openai_gpt4o_mini(mock_azure_openai_client):
+    mock_response = AsyncMock()
+    mock_response.choices = [Mock(message=Mock(content="1"))]
+    messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
+    mock_azure_openai_client.chat.completions.create.return_value = mock_response
+
+    response = await _llm.azure_gpt_4o_mini_complete("2", system_prompt="3")
+
+    mock_azure_openai_client.chat.completions.create.assert_awaited_once_with(
+        model="gpt-4o-mini",
+        messages=messages,
+    )
+    assert response == "1"
+
+
+@pytest.mark.asyncio
+async def test_openai_embedding(mock_openai_client):
+    mock_response = AsyncMock()
+    mock_response.data = [Mock(embedding=[1, 1, 1])]
+    texts = ["Hello world"]
+    mock_openai_client.embeddings.create.return_value = mock_response
+
+    response = await _llm.openai_embedding(texts)
+
+    mock_openai_client.embeddings.create.assert_awaited_once_with(
+        model="text-embedding-3-small", input=texts, encoding_format="float"
+    )
+    # print(response)
+    assert np.allclose(response, np.array([[1, 1, 1]]))
+
+
+@pytest.mark.asyncio
+async def test_azure_openai_embedding(mock_azure_openai_client):
+    mock_response = AsyncMock()
+    mock_response.data = [Mock(embedding=[1, 1, 1])]
+    texts = ["Hello world"]
+    mock_azure_openai_client.embeddings.create.return_value = mock_response
+
+    response = await _llm.azure_openai_embedding(texts)
+
+    mock_azure_openai_client.embeddings.create.assert_awaited_once_with(
+        model="text-embedding-3-small", input=texts, encoding_format="float"
+    )
+    # print(response)
+    assert np.allclose(response, np.array([[1, 1, 1]]))
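
Reviewer note: both fixtures and all six tests lean on the same unittest.mock.AsyncMock
behavior: child attributes of an AsyncMock are themselves AsyncMocks, and awaiting one
returns its return_value, which is what lets chat.completions.create and
embeddings.create be stubbed without any network call. A minimal, stdlib-only sketch of
that pattern (a standalone illustration, not part of the patch):

    import asyncio
    from unittest.mock import AsyncMock, Mock

    async def main():
        # Stands in for AsyncOpenAI() / AsyncAzureOpenAI().
        client = AsyncMock()

        # Awaiting an AsyncMock returns its return_value, so assigning one to
        # the create() child mock stubs the whole completion round-trip.
        mock_response = AsyncMock()
        mock_response.choices = [Mock(message=Mock(content="1"))]
        client.chat.completions.create.return_value = mock_response

        result = await client.chat.completions.create(
            model="gpt-4o", messages=[{"role": "user", "content": "2"}]
        )
        assert result.choices[0].message.content == "1"
        client.chat.completions.create.assert_awaited_once()

    asyncio.run(main())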