
Commit

testing only what the user sees in research tests, also testing with dynamic data
nerfZael committed May 7, 2024
1 parent 75104cb commit 380d828
Showing 14 changed files with 125 additions and 93 deletions.
46 changes: 35 additions & 11 deletions autotx/AutoTx.py
@@ -45,6 +45,7 @@ class RunResult:
end_reason: EndReason
total_cost_without_cache: float
total_cost_with_cache: float
info_messages: list[str]

class AutoTx:
manager: SafeManager
@@ -57,6 +58,7 @@ class AutoTx:
max_rounds: int
current_run_cost_without_cache: float = 0
current_run_cost_with_cache: float = 0
info_messages: list[str] = []

def __init__(
self,
@@ -80,13 +82,13 @@ def __init__(
def run(self, prompt: str, non_interactive: bool, summary_method: str = "last_msg") -> RunResult:
total_cost_without_cache: float = 0
total_cost_with_cache: float = 0
info_messages = []

while True:
self.current_run_costs_without_cache = 0
self.current_run_costs_with_cache = 0
result = self.try_run(prompt, non_interactive, summary_method)
total_cost_without_cache += result.total_cost_without_cache + self.current_run_cost_without_cache
total_cost_with_cache += result.total_cost_with_cache + self.current_run_cost_with_cache
info_messages += result.info_messages

if result.end_reason == EndReason.TERMINATE or non_interactive:
if self.log_costs:
@@ -98,14 +100,29 @@ def run(self, prompt: str, non_interactive: bool, summary_method: str = "last_ms
with open(f"costs/{now_str}.txt", "w") as f:
f.write(str(total_cost_without_cache))

return result
return RunResult(
result.summary,
result.chat_history_json,
result.transactions,
result.end_reason,
total_cost_without_cache,
total_cost_with_cache,
info_messages
)
else:
cprint("Prompt not supported. Please provide a new prompt.", "yellow")
prompt_not_supported = "Prompt not supported. Please provide a new prompt."

cprint(prompt_not_supported, "yellow")
info_messages.append(prompt_not_supported)

prompt = input("Enter a new prompt: ")

def try_run(self, prompt: str, non_interactive: bool, summary_method: str = "last_msg") -> RunResult:
original_prompt = prompt
past_runs: list[PastRun] = []
self.current_run_costs_without_cache = 0
self.current_run_costs_with_cache = 0
self.info_messages = []
self.logger.start()

while True:
@@ -128,12 +145,12 @@ def try_run(self, prompt: str, non_interactive: bool, summary_method: str = "las
+ prev_history
+ "Pay close attention to the user's feedback and try again.\n")

print("Running AutoTx with the following prompt: ", prompt)
self.notify_user("Running AutoTx with the following prompt: " + prompt)

agents_information = self.get_agents_information(self.agents)

user_proxy_agent = user_proxy.build(prompt, agents_information, self.get_llm_config)
clarifier_agent = clarifier.build(user_proxy_agent, agents_information, not non_interactive, self.get_llm_config)
clarifier_agent = clarifier.build(user_proxy_agent, agents_information, not non_interactive, self.get_llm_config, self.notify_user)

helper_agents: list[AutogenAgent] = [
user_proxy_agent,
@@ -142,7 +159,7 @@ def try_run(self, prompt: str, non_interactive: bool, summary_method: str = "las
if not non_interactive:
helper_agents.append(clarifier_agent)

autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config()) for agent in self.agents]
autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config(), self.notify_user) for agent in self.agents]

manager_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config)

@@ -159,9 +176,9 @@ def try_run(self, prompt: str, non_interactive: bool, summary_method: str = "las

if "ERROR:" in chat.summary:
error_message = chat.summary.replace("ERROR: ", "").replace("\n", "")
cprint(error_message, "red")
self.notify_user(error_message, "red")
else:
cprint(chat.summary.replace("\n", ""), "green")
self.notify_user(chat.summary.replace("\n", ""), "green")

is_goal_supported = chat.chat_history[-1]["content"] != "Goal not supported: TERMINATE"

@@ -181,7 +198,7 @@ def try_run(self, prompt: str, non_interactive: bool, summary_method: str = "las
break

except Exception as e:
cprint(e, "red")
self.notify_user(e, "red")
break

self.logger.stop()
@@ -192,7 +209,14 @@ def try_run(self, prompt: str, non_interactive: bool, summary_method: str = "las

chat_history = json.dumps(chat.chat_history, indent=4)

return RunResult(chat.summary, chat_history, transactions, EndReason.TERMINATE if is_goal_supported else EndReason.GOAL_NOT_SUPPORTED, float(chat.cost["usage_including_cached_inference"]["total_cost"]), float(chat.cost["usage_excluding_cached_inference"]["total_cost"]))
return RunResult(chat.summary, chat_history, transactions, EndReason.TERMINATE if is_goal_supported else EndReason.GOAL_NOT_SUPPORTED, float(chat.cost["usage_including_cached_inference"]["total_cost"]), float(chat.cost["usage_excluding_cached_inference"]["total_cost"]), self.info_messages)

def notify_user(self, message: object, color: str | None = None):
if color:
cprint(message, color)
else:
print(message)
self.info_messages.append(message)

def get_agents_information(self, agents: list[AutoTxAgent]) -> str:
agent_descriptions = []
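The heart of the AutoTx.py change is the new notify_user helper: every user-facing print/cprint call is routed through it, and each message is also recorded in info_messages so it can be returned on RunResult and inspected in tests. A minimal, self-contained sketch of that pattern (the class name and prompt strings below are illustrative, not the real AutoTx class):

```python
from dataclasses import dataclass, field

try:
    from termcolor import cprint  # colored console output, as in the real AutoTx code
except ImportError:
    def cprint(message, color=None):  # fallback so the sketch runs without termcolor
        print(message)

@dataclass
class NotifierSketch:
    # Everything shown to the user during a run is also collected here,
    # mirroring RunResult.info_messages in the diff above.
    info_messages: list[str] = field(default_factory=list)

    def notify_user(self, message: object, color: str | None = None) -> None:
        if color:
            cprint(message, color)
        else:
            print(message)
        self.info_messages.append(str(message))  # the real code appends the message as-is

# Usage: call sites that used to print/cprint now go through notify_user.
autotx = NotifierSketch()
autotx.notify_user("Running AutoTx with the following prompt: swap 1 ETH for USDC")
autotx.notify_user("Prepared transaction: Swap 1 ETH for USDC", "green")
assert len(autotx.info_messages) == 2
```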
6 changes: 4 additions & 2 deletions autotx/agents/DelegateResearchTokensAgent.py
@@ -76,7 +76,7 @@ def build_tool(self, autotx: AutoTx) -> Callable[[str], str]:
def run(
tasks: Annotated[str, "User tasks to research"]
) -> str:
print(f"Researching user task:", tasks)
autotx.notify_user(f"Researching user tasks: " + tasks)

user_proxy_agent = UserProxyAgent(
name="user_proxy",
@@ -107,7 +107,7 @@ def run(
code_execution_config=False,
)

research_agent = ResearchTokensAgent().build_autogen_agent(autotx, user_proxy_agent, autotx.get_llm_config())
research_agent = ResearchTokensAgent().build_autogen_agent(autotx, user_proxy_agent, autotx.get_llm_config(), autotx.notify_user)

chat = user_proxy_agent.initiate_chat(
research_agent,
@@ -125,6 +125,8 @@ def run(

summary = aggregate_chat_responses(chat)

autotx.notify_user("Finished researching user tasks")

return summary

return run
2 changes: 1 addition & 1 deletion autotx/agents/ExampleAgent.py
@@ -24,7 +24,7 @@ def run(
receiver: Annotated[str, "The receiver of something."]
) -> str:
# TODO: do something useful
print(f"ExampleTool run: {amount} {receiver}")
autotx.notify_user(f"ExampleTool run: {amount} {receiver}")

# NOTE: you can add transactions to AutoTx's current bundle
# autotx.transactions.append(tx)
8 changes: 4 additions & 4 deletions autotx/agents/ResearchTokensAgent.py
@@ -110,7 +110,7 @@ def build_tool(self, autotx: AutoTx) -> Callable[[str], str]:
def run(
token_id: Annotated[str, "ID of token"]
) -> str:
print(f"Fetching token information for {token_id}")
autotx.notify_user(f"Fetching token information for {token_id}")

token_information = get_coingecko().coins.get_id(
id=token_id,
@@ -166,7 +166,7 @@ def run(
token_symbol: Annotated[str, "Symbol of token to search"],
retrieve_duplicate: Annotated[bool, "Set to true to retrieve all instances of tokens sharing the same symbol, indicating potential duplicates. By default, it is False, meaning only a single, most relevant token is retrieved unless duplication is explicitly requested."]
) -> str:
print(f"Searching for token with symbol: {token_symbol}")
autotx.notify_user(f"Searching for token with symbol: {token_symbol}")

response = get_coingecko().search.get(token_symbol)

@@ -184,7 +184,7 @@ class GetAvailableCategoriesTool(AutoTxTool):

def build_tool(self, autotx: AutoTx) -> Callable[[], str]:
def run() -> str:
print("Fetching available token categories")
autotx.notify_user("Fetching available token categories")

categories = get_coingecko().categories.get_list()
return json.dumps([category["category_id"] for category in categories])
@@ -203,7 +203,7 @@ def run(
price_change_percentage_interval: Annotated[str, "Interval of time in price change percentage. It can be: '1h' | '24h' | '7d' | '14d' | '30d' | '200d' | '1y'. '24h' is the default"],
network_name: Annotated[Optional[str], f"Possible values include: {SUPPORTED_NETWORKS_AS_STRING}. Use this parameter only if you require analysis for a specific network. Otherwise, pass an empty string"]
) -> str:
print(f"Fetching tokens from category: {category}")
autotx.notify_user(f"Fetching tokens from category: {category}")

try:
tokens_in_category = get_coingecko().coins.get_markets(
4 changes: 2 additions & 2 deletions autotx/agents/SendTokensAgent.py
@@ -101,7 +101,7 @@ def run(

autotx.transactions.append(prepared_tx)

print(f"Prepared transaction: {prepared_tx.summary}")
autotx.notify_user(f"Prepared transaction: {prepared_tx.summary}")

return prepared_tx.summary

@@ -131,7 +131,7 @@ def run(
else:
balance = get_erc20_balance(web3, token_address, owner_addr)

print(f"Fetching {token} balance for {str(owner_addr)}: {balance} {token}")
autotx.notify_user(f"Fetching {token} balance for {str(owner_addr)}: {balance} {token}")

return balance

3 changes: 3 additions & 0 deletions autotx/agents/SwapTokensAgent.py
@@ -179,6 +179,9 @@ def run(
for i, tx in enumerate(autotx.transactions)
]
)

autotx.notify_user(summary)

return dedent(
f"""
{summary}
7 changes: 3 additions & 4 deletions autotx/autotx_agent.py
@@ -1,6 +1,5 @@
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
import autogen
from termcolor import cprint
if TYPE_CHECKING:
from autotx.autotx_tool import AutoTxTool
from autotx.AutoTx import AutoTx
@@ -17,7 +16,7 @@ def __init__(self) -> None:
f"{tool.name}: {tool.description}" for tool in self.tools
]

def build_autogen_agent(self, autotx: 'AutoTx', user_proxy: autogen.UserProxyAgent, llm_config: Optional[Dict[str, Any]]) -> autogen.Agent:
def build_autogen_agent(self, autotx: 'AutoTx', user_proxy: autogen.UserProxyAgent, llm_config: Optional[Dict[str, Any]], notify_user: Callable[[object, str | None], None]) -> autogen.Agent:
system_message = None
if isinstance(self.system_message, str):
system_message = self.system_message
@@ -44,9 +43,9 @@ def send_message_hook(
) -> Union[Dict[str, Any], str]:
if recipient.name == "chat_manager" and message != "TERMINATE":
if isinstance(message, str):
cprint(message, "light_yellow")
notify_user(message, "light_yellow")
elif message["content"] != None:
cprint(message["content"], "light_yellow")
notify_user(message["content"], "light_yellow")
return message

agent.register_hook(
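Because agents no longer import cprint themselves, build_autogen_agent (and clarifier.build below) now receive a notify_user callback typed Callable[[object, str | None], None]. A hedged sketch of supplying an alternative callback that satisfies the same signature, for example a silent collector in tests; the helper names here are illustrative, only the signature comes from the diff:

```python
from typing import Callable

# The callback type the agents now receive instead of calling cprint directly.
NotifyUser = Callable[[object, str | None], None]

def make_collecting_notifier(sink: list[str]) -> NotifyUser:
    # A notify_user implementation that records messages instead of printing them.
    def notify(message: object, color: str | None = None) -> None:
        sink.append(str(message))
    return notify

seen: list[str] = []
notify_user: NotifyUser = make_collecting_notifier(seen)

# An agent hook can report progress without knowing whether the message goes to
# a terminal (cprint) or into RunResult.info_messages.
notify_user("Searching for token with symbol: UNI", "light_yellow")
assert seen == ["Searching for token with symbol: UNI"]
```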
9 changes: 3 additions & 6 deletions autotx/helper_agents/clarifier.py
@@ -1,11 +1,8 @@
from textwrap import dedent
from typing import Annotated, Any, Callable, Dict, Optional
from autogen import UserProxyAgent, AssistantAgent, Agent as AutogenAgent
from termcolor import cprint
from autogen import UserProxyAgent, AssistantAgent

from autotx.utils.ethereum.eth_address import ETHAddress

def build(user_proxy: UserProxyAgent, agents_information: str, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]]) -> AssistantAgent:
def build(user_proxy: UserProxyAgent, agents_information: str, interactive: bool, get_llm_config: Callable[[], Optional[Dict[str, Any]]], notify_user: Callable[[object, str | None], None]) -> AssistantAgent:
missing_1 = dedent("""
If the goal is not clear or missing information, you MUST ask for more information by calling the request_user_input tool.
Always ensure you have all the information needed to define the goal that can be executed without prior context.
@@ -73,7 +70,7 @@ def request_user_input(
def goal_outside_scope(
message: Annotated[str, "The message return to the user about why the goal is outside of the supported scope"],
) -> str:
cprint(f"Goal not supported: {message}", "red")
notify_user(f"Goal not supported: {message}", "red")
return "Goal not supported: TERMINATE"

clarifier_agent.register_for_llm(name="goal_outside_scope", description="Notify the user about their goal not being in the scope of the agents")(goal_outside_scope)
1 change: 1 addition & 0 deletions autotx/helper_agents/user_proxy.py
@@ -28,6 +28,7 @@ def build(user_prompt: str, agents_information: str, get_llm_config: Callable[[]
If you encounter an error, try to resolve it (either yourself of with other agents) and only respond with 'TERMINATE' if the goal is truly not achievable.
Try to find an alternative solution if the goal is not achievable.
If a token is not supported, ask the 'research-tokens' agent to find a supported token (if it fits within the user's goal).
Before you end the conversation, make sure to summarize the results.
"""
),
description="user_proxy is an agent authorized to act on behalf of the user.",
21 changes: 16 additions & 5 deletions autotx/tests/agents/token/research/test_advanced.py
@@ -1,3 +1,4 @@
from autotx.tests.agents.token.research.test_research import get_top_token_addresses_by_market_cap
from autotx.utils.ethereum.eth_address import ETHAddress

def test_research_and_swap_many_tokens_subjective_simple(configuration, auto_tx):
@@ -15,6 +16,15 @@ def test_research_and_swap_many_tokens_subjective_simple(configuration, auto_tx)

ending_balance = manager.balance_of()

gaming_token_address = get_top_token_addresses_by_market_cap("gaming", "MAINNET", 1, auto_tx)[0]
gaming_token_balance_in_safe = manager.balance_of(gaming_token_address)

ai_token_address = get_top_token_addresses_by_market_cap("ai-themed-coins", "MAINNET", 1, auto_tx)[0]
ai_token_balance_in_safe = manager.balance_of(ai_token_address)

meme_token_address = get_top_token_addresses_by_market_cap("meme-token", "MAINNET", 1, auto_tx)[0]
meme_token_balance_in_safe = manager.balance_of(meme_token_address)

# Verify the balance is lower by max 3 ETH
assert starting_balance - ending_balance <= 3
# Verify there are at least 3 transactions
@@ -24,12 +34,13 @@ def test_research_and_swap_many_tokens_subjective_simple(configuration, auto_tx)
# Verify the tokens are different
assert len(set([tx.summary.split(" ")[-1] for tx in result.transactions])) == 3

# Verify the tokens are in the safe
assert gaming_token_balance_in_safe > 0
assert ai_token_balance_in_safe > 0
assert meme_token_balance_in_safe > 0

def test_research_and_swap_many_tokens_subjective_complex(configuration, auto_tx):
(_, _, _, manager) = configuration
uni_address = ETHAddress(auto_tx.network.tokens["uni"])

uni_balance_in_safe = manager.balance_of(uni_address)
assert uni_balance_in_safe == 0

starting_balance = manager.balance_of()

@@ -46,4 +57,4 @@ def test_research_and_swap_many_tokens_subjective_complex(configuration, auto_tx
# Verify there are only swap transactions
assert all([tx.summary.startswith("Swap") for tx in result.transactions])
# Verify the tokens are different
assert len(set([tx.summary.split(" ")[-1] for tx in result.transactions])) == 10
assert len(set([tx.summary.split(" ")[-1] for tx in result.transactions])) == 10
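The test_advanced.py tests above now verify outcomes with dynamic data: instead of hardcoding token addresses, they resolve the current top token per category at run time and check that the Safe actually holds it. A sketch of that pattern with stubs in place of CoinGecko and the SafeManager (all names and values below are illustrative):

```python
# Stub for get_top_token_addresses_by_market_cap, which in the real test queries
# CoinGecko markets and maps symbols to addresses via auto_tx.network.tokens.
def top_token_address_stub(category: str) -> str:
    return {"gaming": "0xGAME", "ai-themed-coins": "0xAI", "meme-token": "0xMEME"}[category]

# Stub for manager.balance_of(address) on the test Safe.
safe_balances = {"0xGAME": 1.2, "0xAI": 0.8, "0xMEME": 3.0}

for category in ("gaming", "ai-themed-coins", "meme-token"):
    address = top_token_address_stub(category)
    # Mirrors the test's assertion: the swapped-into token must now be held by the Safe.
    assert safe_balances[address] > 0
```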
19 changes: 14 additions & 5 deletions autotx/tests/agents/token/research/test_research.py
@@ -3,8 +3,17 @@
get_coingecko,
)

from autotx.utils.ethereum.eth_address import ETHAddress
from autotx.utils.ethereum.networks import ChainId

def get_top_token_addresses_by_market_cap(category: str, network: str, count: int, auto_tx) -> list[ETHAddress]:
tokens = get_coingecko().coins.get_markets(vs_currency="usd", category=category, per_page=250)
tokens_in_network = filter_token_list_by_network(
tokens, network
)

return [ETHAddress(auto_tx.network.tokens[token["symbol"].lower()]) for token in tokens_in_network[:count]]

def test_price_change_information(auto_tx):
token_information = get_coingecko().coins.get_id(
id="starknet",
@@ -21,7 +30,7 @@ def test_price_change_information(auto_tx):
price_change_rounded = round(price_change, 2)

assert (
str(price_change) in result.chat_history_json or str(price_change_rounded) in result.chat_history_json
str(price_change) in "\n".join(result.info_messages).lower() or str(price_change_rounded) in "\n".join(result.info_messages).lower()
)

def test_get_top_5_tokens_from_base(auto_tx):
@@ -34,7 +43,7 @@ def test_get_top_5_tokens_from_base(auto_tx):

for token in tokens[:5]:
symbol: str = token["symbol"]
assert symbol.lower() in result.chat_history_json.lower()
assert symbol.lower() in "\n".join(result.info_messages).lower()

def test_get_top_5_most_traded_tokens_from_l1(auto_tx):
tokens = get_coingecko().coins.get_markets(
@@ -46,7 +55,7 @@ def test_get_top_5_most_traded_tokens_from_l1(auto_tx):

for token in tokens[:5]:
symbol: str = token["symbol"]
assert symbol.lower() in result.chat_history_json.lower()
assert symbol.lower() in "\n".join(result.info_messages).lower()

def test_get_top_5_memecoins(auto_tx):
tokens = get_coingecko().coins.get_markets(vs_currency="usd", category="meme-token")
@@ -60,7 +69,7 @@ def test_get_top_5_memecoins(auto_tx):

for token in tokens_in_network[:5]:
symbol: str = token["symbol"]
assert symbol.lower() in result.chat_history_json.lower()
assert symbol.lower() in "\n".join(result.info_messages).lower()

def test_get_top_5_memecoins_in_optimism(auto_tx):
tokens = get_coingecko().coins.get_markets(vs_currency="usd", category="meme-token")
@@ -71,4 +80,4 @@ def test_get_top_5_memecoins_in_optimism(auto_tx):
tokens = filter_token_list_by_network(tokens, ChainId.OPTIMISM.name)
for token in tokens[:5]:
symbol: str = token["symbol"]
assert symbol.lower() in result.chat_history_json.lower()
assert symbol.lower() in "\n".join(result.info_messages).lower()
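These research tests now assert against what the user actually saw (result.info_messages) rather than the raw chat_history_json. A sketch of the shared assertion pattern with static stand-in data; the real tests fetch live market data from CoinGecko, per the commit's "dynamic data":

```python
def assert_symbols_were_reported(info_messages: list[str], tokens: list[dict]) -> None:
    # Join every user-visible message into one lowercase haystack, as the tests do.
    haystack = "\n".join(info_messages).lower()
    for token in tokens[:5]:
        symbol: str = token["symbol"]
        assert symbol.lower() in haystack, f"{symbol} was never shown to the user"

# Static example; the real tests pull the token list from get_coingecko().coins.get_markets(...).
messages = ["Top tokens on Base by market cap: WETH, USDC, DAI, cbETH, AERO"]
tokens = [{"symbol": s} for s in ("WETH", "USDC", "DAI", "cbETH", "AERO")]
assert_symbols_were_reported(messages, tokens)
```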
(Remaining changed files not shown.)
