
chore: move use_auto_chat_cache_seed_gen and init_chat_cache_seed to LLMSettings #444

Merged: 2 commits, Oct 21, 2024
10 changes: 0 additions & 10 deletions rdagent/core/conf.py
@@ -15,16 +15,6 @@ class RDAgentSettings(BaseSettings):
     # TODO: (xiao) think it can be a separate config.
     log_trace_path: str | None = None

-    # Behavior of returning answers to the same question when caching is enabled
-    use_auto_chat_cache_seed_gen: bool = False
-    """
-    `_create_chat_completion_inner_function` provides a way to pass in a seed that affects the cache hash key.
-    We want an auto seed generator to supply a different default seed for `_create_chat_completion_inner_function`
-    when no seed is given.
-    So the cache only hits when the same question is asked in the same round.
-    """
-    init_chat_cache_seed: int = 42

     # azure document intelligence configs
     azure_document_intelligence_key: str = ""
     azure_document_intelligence_endpoint: str = ""
3 changes: 2 additions & 1 deletion rdagent/core/utils.py
@@ -14,6 +14,7 @@
 from fuzzywuzzy import fuzz  # type: ignore[import-untyped]

 from rdagent.core.conf import RD_AGENT_SETTINGS
+from rdagent.oai.llm_conf import LLM_SETTINGS


 class RDAgentException(Exception):  # noqa: N818
@@ -98,7 +99,7 @@ class CacheSeedGen:
     """

     def __init__(self) -> None:
-        self.set_seed(RD_AGENT_SETTINGS.init_chat_cache_seed)
+        self.set_seed(LLM_SETTINGS.init_chat_cache_seed)

     def set_seed(self, seed: int) -> None:
         random.seed(seed)
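For context on the class being edited: only `__init__` and `set_seed` are visible in this hunk. A minimal sketch of how `CacheSeedGen` plausibly fits together, assuming `get_next_seed` draws from the globally seeded RNG (that method's real body is not shown in the diff):

import random


class CacheSeedGen:
    """Generates the default seeds used for chat-cache keys (sketch).

    Seeding the RNG once up front makes the sequence of generated seeds
    reproducible, so cache hits repeat across identical runs.
    """

    def __init__(self) -> None:
        self.set_seed(42)  # LLM_SETTINGS.init_chat_cache_seed after this PR

    def set_seed(self, seed: int) -> None:
        random.seed(seed)

    def get_next_seed(self) -> int:
        # Assumption: draw the next default seed from the seeded global RNG.
        return random.randint(0, 10**8)


# The `LLM_CACHE_SEED_GEN` imported elsewhere in this diff suggests a
# module-level singleton instance:
LLM_CACHE_SEED_GEN = CacheSeedGen()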
10 changes: 10 additions & 0 deletions rdagent/oai/llm_conf.py
@@ -20,6 +20,16 @@ class LLMSettings(BaseSettings):
     prompt_cache_path: str = str(Path.cwd() / "prompt_cache.db")
     max_past_message_include: int = 10

+    # Behavior of returning answers to the same question when caching is enabled
+    use_auto_chat_cache_seed_gen: bool = False
+    """
+    `_create_chat_completion_inner_function` provides a way to pass in a seed that affects the cache hash key.
+    We want an auto seed generator to supply a different default seed for `_create_chat_completion_inner_function`
+    when no seed is given.
+    So the cache only hits when the same question is asked in the same round.
+    """
+    init_chat_cache_seed: int = 42

     # Chat configs
     openai_api_key: str = ""  # TODO: simplify the key design.
     chat_openai_api_key: str = ""
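Since `LLMSettings` is a pydantic `BaseSettings`, the two relocated fields can be overridden per run. A short sketch under that assumption (any `env_prefix` configured on the class, which this diff does not show, would change the variable names):

import os

# Via environment, set before rdagent.oai.llm_conf is first imported:
os.environ["USE_AUTO_CHAT_CACHE_SEED_GEN"] = "True"
os.environ["INIT_CHAT_CACHE_SEED"] = "7"

# Or directly in code, as the updated tests below do:
from rdagent.oai.llm_conf import LLM_SETTINGS

LLM_SETTINGS.use_auto_chat_cache_seed_gen = True
LLM_SETTINGS.init_chat_cache_seed = 7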
3 changes: 1 addition & 2 deletions rdagent/oai/llm_utils.py
@@ -17,7 +17,6 @@
 import numpy as np
 import tiktoken

-from rdagent.core.conf import RD_AGENT_SETTINGS
 from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass
 from rdagent.log import LogColors
 from rdagent.log import rdagent_logger as logger
@@ -596,7 +595,7 @@ def _create_chat_completion_inner_function(  # noqa: C901, PLR0912, PLR0915
         To make retries useful, we need to enable a seed.
         This seed is different from `self.chat_seed` for GPT. It is for the local cache mechanism enabled by RD-Agent locally.
         """
-        if seed is None and RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen:
+        if seed is None and LLM_SETTINGS.use_auto_chat_cache_seed_gen:
             seed = LLM_CACHE_SEED_GEN.get_next_seed()

         # TODO: we can add this function back to avoid so much `self.cfg.log_llm_chat_content`
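The diff does not show how the seed enters the cache lookup, only that it "affects the cache hash key". One way to picture the mechanism, as a sketch with an assumed hash and key layout:

import hashlib
import json


def chat_cache_key(messages: list[dict], seed: int | None) -> str:
    # Mixing the seed into the hashed payload changes the key, so identical
    # messages with a different seed miss the cache and trigger a fresh call.
    payload = json.dumps({"messages": messages, "seed": seed}, sort_keys=True)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()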
14 changes: 6 additions & 8 deletions test/oai/test_completion.py
@@ -60,23 +60,22 @@ def test_chat_cache(self) -> None:
         - 2 pass
         - cache is not missed & the same question gets a different answer.
         """
-        from rdagent.core.conf import RD_AGENT_SETTINGS
         from rdagent.core.utils import LLM_CACHE_SEED_GEN
         from rdagent.oai.llm_conf import LLM_SETTINGS

         system_prompt = "You are a helpful assistant."
         user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"

         origin_value = (
-            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_auto_chat_cache_seed_gen,
             LLM_SETTINGS.use_chat_cache,
             LLM_SETTINGS.dump_chat_cache,
         )

         LLM_SETTINGS.use_chat_cache = True
         LLM_SETTINGS.dump_chat_cache = True

-        RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True
+        LLM_SETTINGS.use_auto_chat_cache_seed_gen = True

         LLM_CACHE_SEED_GEN.set_seed(10)
         response1 = APIBackend().build_messages_and_create_chat_completion(
@@ -110,7 +109,7 @@ def test_chat_cache(self) -> None:

         # Reset, for other tests
         (
-            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_auto_chat_cache_seed_gen,
             LLM_SETTINGS.use_chat_cache,
             LLM_SETTINGS.dump_chat_cache,
         ) = origin_value
@@ -132,23 +131,22 @@ def test_chat_cache_multiprocess(self) -> None:
         - 2 pass
         - cache is not missed & the same question gets a different answer.
         """
-        from rdagent.core.conf import RD_AGENT_SETTINGS
         from rdagent.core.utils import LLM_CACHE_SEED_GEN, multiprocessing_wrapper
         from rdagent.oai.llm_conf import LLM_SETTINGS

         system_prompt = "You are a helpful assistant."
         user_prompt = f"Give me {2} random country names, list {2} cities in each country, and introduce them"

         origin_value = (
-            RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
+            LLM_SETTINGS.use_auto_chat_cache_seed_gen,
             LLM_SETTINGS.use_chat_cache,
             LLM_SETTINGS.dump_chat_cache,
         )

         LLM_SETTINGS.use_chat_cache = True
         LLM_SETTINGS.dump_chat_cache = True

-        RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True
+        LLM_SETTINGS.use_auto_chat_cache_seed_gen = True

         func_calls = [(_worker, (system_prompt, user_prompt)) for _ in range(4)]

Expand All @@ -161,7 +159,7 @@ def test_chat_cache_multiprocess(self) -> None:

# Reset, for other tests
(
RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_auto_chat_cache_seed_gen,
LLM_SETTINGS.use_chat_cache,
LLM_SETTINGS.dump_chat_cache,
) = origin_value
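Taken together, the behavior these tests pin down can be sketched as follows (the keyword arguments for `build_messages_and_create_chat_completion` are assumed from the truncated calls above; caching flags must be enabled as in the tests):

from rdagent.core.utils import LLM_CACHE_SEED_GEN
from rdagent.oai.llm_utils import APIBackend

system_prompt = "You are a helpful assistant."
user_prompt = "Give me 2 random country names, list 2 cities in each country, and introduce them"

LLM_CACHE_SEED_GEN.set_seed(10)
r1 = APIBackend().build_messages_and_create_chat_completion(
    user_prompt=user_prompt, system_prompt=system_prompt
)
LLM_CACHE_SEED_GEN.set_seed(10)
r2 = APIBackend().build_messages_and_create_chat_completion(
    user_prompt=user_prompt, system_prompt=system_prompt
)
assert r1 == r2  # same seed sequence -> same auto seed -> cache hit

LLM_CACHE_SEED_GEN.set_seed(20)
r3 = APIBackend().build_messages_and_create_chat_completion(
    user_prompt=user_prompt, system_prompt=system_prompt
)
# Different seed -> different cache key -> fresh (typically different) answer.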