Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix timeout #128

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions council/llm/azure_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def post_request(self, payload: dict[str, Any]) -> httpx.Response:

timeout = self.config.timeout.value
try:
with httpx.Client() as client:
client.timeout.read = timeout
with httpx.Client(timeout=timeout) as client:
return client.post(url=self._uri, headers=headers, params=params, json=payload)
except TimeoutException as e:
raise LLMCallTimeoutException(timeout) from e
Expand Down
3 changes: 1 addition & 2 deletions council/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def post_request(self, payload: dict[str, Any]) -> httpx.Response:

timeout = self.config.timeout.unwrap()
try:
with httpx.Client() as client:
client.timeout.read = timeout
with httpx.Client(timeout=timeout) as client:
return client.post(url=uri, headers=self._headers, json=payload)
except TimeoutException as e:
raise LLMCallTimeoutException(timeout) from e
Expand Down
3 changes: 3 additions & 0 deletions council/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,6 @@ def _set(self, value: Optional[str]):
os.environ.pop(self.name, None)
if value is not None:
os.environ[self.name] = value

def __str__(self) -> str:
    """Human-readable summary of this env-var override: name, current value, and the value it replaced."""
    current = self.value
    previous = self.previous_value
    return "Env var:`{}` value:{} (previous value: {})".format(self.name, current, previous)
41 changes: 14 additions & 27 deletions tests/integration/llm/test_azure_llm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
import unittest

import dotenv

from council.contexts import LLMContext
from council.llm import AzureLLM, LLMMessage, LLMException
from council.utils import ParameterValueException
from council.llm import AzureLLM, LLMMessage, LLMException, LLMCallTimeoutException
from council.utils import ParameterValueException, OsEnviron


class TestLlmAzure(unittest.TestCase):
Expand Down Expand Up @@ -39,41 +38,29 @@ def test_censored_prompt(self):
self.assertIn("censored", str(e))

def test_max_token(self):
os.environ["AZURE_LLM_MAX_TOKENS"] = "5"

try:
with OsEnviron("AZURE_LLM_MAX_TOKENS", "5"):
llm = AzureLLM.from_env()
messages = [LLMMessage.user_message("Give me an example of a currency")]
result = llm.post_chat_request(LLMContext.empty(), messages)
self.assertTrue(len(result.first_choice.replace(" ", "")) <= 5 * 5)
finally:
del os.environ["AZURE_LLM_MAX_TOKENS"]

self.assertEquals(os.getenv("AZURE_LLM_MAX_TOKENS"), None)

def test_choices(self):
os.environ["AZURE_LLM_N"] = "3"
os.environ["AZURE_LLM_TEMPERATURE"] = "1.0"

try:
with OsEnviron("AZURE_LLM_N", "3"), OsEnviron("AZURE_LLM_TEMPERATURE", "1.0"):
llm = AzureLLM.from_env()
messages = [LLMMessage.user_message("Give me an example of a currency")]
result = llm.post_chat_request(LLMContext.empty(), messages)
self.assertEquals(3, len(result.choices))
[print("\n- Choice:" + choice) for choice in result.choices]
finally:
del os.environ["AZURE_LLM_N"]
del os.environ["AZURE_LLM_TEMPERATURE"]

self.assertEquals(os.getenv("AZURE_LLM_N"), None)
self.assertEquals(os.getenv("AZURE_LLM_TEMPERATURE"), None)

def test_invalid_temperature(self):
os.environ["AZURE_LLM_TEMPERATURE"] = "3.5"

with self.assertRaises(ParameterValueException) as cm:
_ = AzureLLM.from_env()
print(cm.exception)
del os.environ["AZURE_LLM_TEMPERATURE"]
with OsEnviron("AZURE_LLM_TEMPERATURE", "3.5"):
with self.assertRaises(ParameterValueException) as cm:
_ = AzureLLM.from_env()
print(cm.exception)

self.assertEquals(os.getenv("AZURE_LLM_TEMPERATURE"), None)
def test_time_out(self):
    """A 1-second AZURE_LLM_TIMEOUT must make a long completion request raise LLMCallTimeoutException."""
    with OsEnviron("AZURE_LLM_TIMEOUT", "1"):
        model = AzureLLM.from_env()
        prompt = [LLMMessage.user_message("Give a full explanation of quantum intrication ")]
        with self.assertRaises(LLMCallTimeoutException):
            _ = model.post_chat_request(LLMContext.empty(), prompt)
Loading