From 9f8fdb342c7230c9c10eb827769792b26aad3c68 Mon Sep 17 00:00:00 2001 From: Arnaud Flament Date: Mon, 29 Jan 2024 15:40:47 -0800 Subject: [PATCH] Fix timeout --- council/llm/azure_llm.py | 3 +- council/llm/openai_llm.py | 3 +- council/utils/env.py | 3 ++ tests/integration/llm/test_azure_llm.py | 41 +++++++++---------------- 4 files changed, 19 insertions(+), 31 deletions(-) diff --git a/council/llm/azure_llm.py b/council/llm/azure_llm.py index fc645886..801f4f0b 100644 --- a/council/llm/azure_llm.py +++ b/council/llm/azure_llm.py @@ -26,8 +26,7 @@ def post_request(self, payload: dict[str, Any]) -> httpx.Response: timeout = self.config.timeout.value try: - with httpx.Client() as client: - client.timeout.read = timeout + with httpx.Client(timeout=timeout) as client: return client.post(url=self._uri, headers=headers, params=params, json=payload) except TimeoutException as e: raise LLMCallTimeoutException(timeout) from e diff --git a/council/llm/openai_llm.py b/council/llm/openai_llm.py index 8622a3fa..89aa3f6d 100644 --- a/council/llm/openai_llm.py +++ b/council/llm/openai_llm.py @@ -25,8 +25,7 @@ def post_request(self, payload: dict[str, Any]) -> httpx.Response: timeout = self.config.timeout.unwrap() try: - with httpx.Client() as client: - client.timeout.read = timeout + with httpx.Client(timeout=timeout) as client: return client.post(url=uri, headers=self._headers, json=payload) except TimeoutException as e: raise LLMCallTimeoutException(timeout) from e diff --git a/council/utils/env.py b/council/utils/env.py index b31fca2e..d0491887 100644 --- a/council/utils/env.py +++ b/council/utils/env.py @@ -110,3 +110,6 @@ def _set(self, value: Optional[str]): os.environ.pop(self.name, None) if value is not None: os.environ[self.name] = value + + def __str__(self): + return f"Env var:`{self.name}` value:{self.value} (previous value: {self.previous_value})" diff --git a/tests/integration/llm/test_azure_llm.py b/tests/integration/llm/test_azure_llm.py index 1d3036c7..9438d47a 
100644 --- a/tests/integration/llm/test_azure_llm.py +++ b/tests/integration/llm/test_azure_llm.py @@ -1,11 +1,10 @@ -import os import unittest import dotenv from council.contexts import LLMContext -from council.llm import AzureLLM, LLMMessage, LLMException -from council.utils import ParameterValueException +from council.llm import AzureLLM, LLMMessage, LLMException, LLMCallTimeoutException +from council.utils import ParameterValueException, OsEnviron class TestLlmAzure(unittest.TestCase): @@ -39,41 +38,29 @@ def test_censored_prompt(self): self.assertIn("censored", str(e)) def test_max_token(self): - os.environ["AZURE_LLM_MAX_TOKENS"] = "5" - - try: + with OsEnviron("AZURE_LLM_MAX_TOKENS", "5"): llm = AzureLLM.from_env() messages = [LLMMessage.user_message("Give me an example of a currency")] result = llm.post_chat_request(LLMContext.empty(), messages) self.assertTrue(len(result.first_choice.replace(" ", "")) <= 5 * 5) - finally: - del os.environ["AZURE_LLM_MAX_TOKENS"] - - self.assertEquals(os.getenv("AZURE_LLM_MAX_TOKENS"), None) def test_choices(self): - os.environ["AZURE_LLM_N"] = "3" - os.environ["AZURE_LLM_TEMPERATURE"] = "1.0" - - try: + with OsEnviron("AZURE_LLM_N", "3"), OsEnviron("AZURE_LLM_TEMPERATURE", "1.0"): llm = AzureLLM.from_env() messages = [LLMMessage.user_message("Give me an example of a currency")] result = llm.post_chat_request(LLMContext.empty(), messages) self.assertEquals(3, len(result.choices)) [print("\n- Choice:" + choice) for choice in result.choices] - finally: - del os.environ["AZURE_LLM_N"] - del os.environ["AZURE_LLM_TEMPERATURE"] - - self.assertEquals(os.getenv("AZURE_LLM_N"), None) - self.assertEquals(os.getenv("AZURE_LLM_TEMPERATURE"), None) def test_invalid_temperature(self): - os.environ["AZURE_LLM_TEMPERATURE"] = "3.5" - - with self.assertRaises(ParameterValueException) as cm: - _ = AzureLLM.from_env() - print(cm.exception) - del os.environ["AZURE_LLM_TEMPERATURE"] + with OsEnviron("AZURE_LLM_TEMPERATURE", "3.5"): + with 
self.assertRaises(ParameterValueException) as cm: + _ = AzureLLM.from_env() + print(cm.exception) - self.assertEquals(os.getenv("AZURE_LLM_TEMPERATURE"), None) + def test_time_out(self): + with OsEnviron("AZURE_LLM_TIMEOUT", "1"): + llm = AzureLLM.from_env() + messages = [LLMMessage.user_message("Give a full explanation of quantum entanglement ")] + with self.assertRaises(LLMCallTimeoutException): + _ = llm.post_chat_request(LLMContext.empty(), messages)