diff --git a/aide/backend/__init__.py b/aide/backend/__init__.py index 208956a..77f6145 100644 --- a/aide/backend/__init__.py +++ b/aide/backend/__init__.py @@ -33,6 +33,14 @@ def query( "max_tokens": max_tokens, } + # Handle models with beta limitations + # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations + if model.startswith("o1-"): + if system_message: + user_message = system_message + system_message = None + model_kwargs["temperature"] = 1 + query_func = backend_anthropic.query if "claude-" in model else backend_openai.query output, req_time, in_tok_count, out_tok_count, info = query_func( system_message=compile_prompt_to_md(system_message) if system_message else None, diff --git a/aide/backend/backend_anthropic.py b/aide/backend/backend_anthropic.py index 6b11a74..a28cc93 100644 --- a/aide/backend/backend_anthropic.py +++ b/aide/backend/backend_anthropic.py @@ -2,23 +2,25 @@ import time -from anthropic import Anthropic, RateLimitError -from .utils import FunctionSpec, OutputType, opt_messages_to_list -from funcy import notnone, once, retry, select_values +from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create +from funcy import notnone, once, select_values +import anthropic -_client: Anthropic = None # type: ignore +_client: anthropic.Anthropic = None # type: ignore -RATELIMIT_RETRIES = 5 -retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1)) # type: ignore +ANTHROPIC_TIMEOUT_EXCEPTIONS = ( + anthropic.RateLimitError, + anthropic.APIConnectionError, + anthropic.APITimeoutError, + anthropic.InternalServerError, +) @once def _setup_anthropic_client(): global _client - _client = Anthropic() + _client = anthropic.Anthropic(max_retries=0) - -@retry_exp def query( system_message: str | None, user_message: str | None, @@ -48,7 +50,12 @@ def query( messages = opt_messages_to_list(None, user_message) t0 = time.time() - message = _client.messages.create(messages=messages, **filtered_kwargs) # type: ignore + message = backoff_create( + _client.messages.create, + ANTHROPIC_TIMEOUT_EXCEPTIONS, + messages=messages, + **filtered_kwargs, + ) req_time = time.time() - t0 assert len(message.content) == 1 and message.content[0].type == "text" diff --git a/aide/backend/backend_openai.py b/aide/backend/backend_openai.py index a8829d4..a69aade 100644 --- a/aide/backend/backend_openai.py +++ b/aide/backend/backend_openai.py @@ -4,25 +4,26 @@ import logging import time -from .utils import FunctionSpec, OutputType, opt_messages_to_list -from funcy import notnone, once, retry, select_values -from openai import OpenAI, RateLimitError +from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create +from funcy import notnone, once, select_values +import openai logger = logging.getLogger("aide") -_client: OpenAI = None # type: ignore - -RATELIMIT_RETRIES = 5 -retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1)) # type: ignore +_client: openai.OpenAI = None # type: ignore +OPENAI_TIMEOUT_EXCEPTIONS = ( + openai.RateLimitError, + openai.APIConnectionError, + openai.APITimeoutError, + openai.InternalServerError, +) @once def _setup_openai_client(): global _client - _client = OpenAI(max_retries=3) - + _client = openai.OpenAI(max_retries=0) -@retry_exp def query( system_message: str | None, user_message: str | None, @@ -40,7 +41,12 @@ def query( filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict t0 = time.time() - completion = _client.chat.completions.create(messages=messages, **filtered_kwargs) # type: ignore + completion = backoff_create( + _client.chat.completions.create, + OPENAI_TIMEOUT_EXCEPTIONS, + messages=messages, + **filtered_kwargs, + ) req_time = time.time() - t0 choice = completion.choices[0] diff --git a/aide/backend/utils.py b/aide/backend/utils.py index 1e2de3d..d223b1b 100644 --- a/aide/backend/utils.py +++ b/aide/backend/utils.py @@ -8,6 +8,27 @@ OutputType = str | FunctionCallType +import backoff +import logging +from typing import Callable + +logger = logging.getLogger("aide") + + +@backoff.on_predicate( + wait_gen=backoff.expo, + max_value=60, + factor=1.5, +) +def backoff_create( + create_fn: Callable, retry_exceptions: list[Exception], *args, **kwargs +): + try: + return create_fn(*args, **kwargs) + except retry_exceptions as e: + logger.info(f"Backoff exception: {e}") + return False + def opt_messages_to_list( system_message: str | None, user_message: str | None ) -> list[dict[str, str]]: diff --git a/requirements.txt b/requirements.txt index 5248639..6c2c6c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -88,4 +88,5 @@ pdf2image PyPDF pyocr pyarrow -xlrd \ No newline at end of file +xlrd +backoff