-
-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: replace API call with build_cls_from_json_with_retry function (#548
) * refactor: Replace API call with build_cls_from_json_with_retry function * fix lint error * fix lint errors * lint * trigger
- Loading branch information
Showing
8 changed files
with
77 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .workflow import build_cls_from_json_with_retry | ||
|
||
__all__ = ["build_cls_from_json_with_retry"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import json | ||
from typing import Type, TypeVar | ||
|
||
from rdagent.core.exception import FormatError | ||
from rdagent.log import rdagent_logger as logger | ||
from rdagent.oai.llm_utils import APIBackend | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def build_cls_from_json_with_retry( | ||
cls: Type[T], system_prompt: str, user_prompt: str, retry_n: int = 5, **kwargs: dict | ||
) -> T: | ||
""" | ||
Parameters | ||
---------- | ||
cls : Type[T] | ||
The class type to be instantiated with the response data. | ||
system_prompt : str | ||
The initial prompt provided to the system for context. | ||
user_prompt : str | ||
The prompt given by the user to guide the response generation. | ||
retry_n : int | ||
The number of attempts to retry in case of failure. | ||
**kwargs | ||
Additional keyword arguments passed to the API call. | ||
Returns | ||
------- | ||
T | ||
An instance of the specified class type created from the response data. | ||
""" | ||
for i in range(retry_n): | ||
# currently, it only handle exception caused by initial class | ||
resp = APIBackend().build_messages_and_create_chat_completion( | ||
user_prompt=user_prompt, system_prompt=system_prompt, json_mode=True, **kwargs # type: ignore[arg-type] | ||
) | ||
try: | ||
return cls(**json.loads(resp)) | ||
except Exception as e: | ||
logger.warning(f"Attempt {i + 1}: The previous attempt didn't work due to: {e}") | ||
user_prompt = user_prompt + f"\n\nAttempt {i + 1}: The previous attempt didn't work due to: {e}" | ||
else: | ||
raise FormatError("Unable to produce a JSON response that meets the specified requirements.") |