Skip to content

Commit

Permalink
Merge pull request #259 from polywrap/nerfzael/intents
Browse files Browse the repository at this point in the history
Intents
  • Loading branch information
cbrzn authored Jun 17, 2024
2 parents 9984baa + 43c18ab commit d4663b0
Show file tree
Hide file tree
Showing 19 changed files with 540 additions and 213 deletions.
37 changes: 19 additions & 18 deletions autotx/AutoTx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from autotx import models
from autotx.autotx_agent import AutoTxAgent
from autotx.helper_agents import clarifier, manager, user_proxy
from autotx.intents import Intent
from autotx.utils.color import Color
from autotx.utils.logging.Logger import Logger
from autotx.utils.ethereum.networks import NetworkInfo
Expand All @@ -38,7 +39,7 @@ def __init__(self, verbose: bool, get_llm_config: Callable[[], Optional[Dict[str
@dataclass
class PastRun:
feedback: str
transactions_info: str
intents_info: str

class EndReason(Enum):
TERMINATE = "TERMINATE"
Expand All @@ -48,7 +49,7 @@ class EndReason(Enum):
class RunResult:
summary: str
chat_history_json: str
transactions: list[models.Transaction]
intents: list[Intent]
end_reason: EndReason
total_cost_without_cache: float
total_cost_with_cache: float
Expand All @@ -58,7 +59,7 @@ class AutoTx:
web3: Web3
wallet: SmartWallet
logger: Logger
transactions: list[models.Transaction]
intents: list[Intent]
network: NetworkInfo
get_llm_config: Callable[[], Optional[Dict[str, Any]]]
agents: list[AutoTxAgent]
Expand Down Expand Up @@ -91,7 +92,7 @@ def __init__(
self.max_rounds = config.max_rounds
self.verbose = config.verbose
self.get_llm_config = config.get_llm_config
self.transactions = []
self.intents = []
self.current_run_cost_without_cache = 0
self.current_run_cost_with_cache = 0
self.info_messages = []
Expand Down Expand Up @@ -128,7 +129,7 @@ async def a_run(self, prompt: str, non_interactive: bool, summary_method: str =
return RunResult(
result.summary,
result.chat_history_json,
result.transactions,
result.intents,
result.end_reason,
total_cost_without_cache,
total_cost_with_cache,
Expand All @@ -152,13 +153,13 @@ async def try_run(self, prompt: str, non_interactive: bool, summary_method: str

while True:
if past_runs:
self.transactions.clear()
self.intents.clear()

prev_history = "".join(
[
dedent(f"""
Then you prepared these transactions to accomplish the goal:
{run.transactions_info}
{run.intents_info}
Then the user provided feedback:
{run.feedback}
""")
Expand Down Expand Up @@ -208,17 +209,17 @@ async def try_run(self, prompt: str, non_interactive: bool, summary_method: str
is_goal_supported = chat.chat_history[-1]["content"] != "Goal not supported: TERMINATE"

try:
result = self.wallet.on_transactions_ready(self.transactions)
result = self.wallet.on_intents_ready(self.intents)

if isinstance(result, str):
transactions_info ="\n".join(
intents_info ="\n".join(
[
f"{i + 1}. {tx.summary}"
for i, tx in enumerate(self.transactions)
for i, tx in enumerate(self.intents)
]
)

past_runs.append(PastRun(result, transactions_info))
past_runs.append(PastRun(result, intents_info))
else:
break

Expand All @@ -228,17 +229,17 @@ async def try_run(self, prompt: str, non_interactive: bool, summary_method: str

self.logger.stop()

# Copy transactions to a new list to avoid modifying the original list
transactions = self.transactions.copy()
self.transactions.clear()
# Copy intents to a new list to avoid modifying the original list
intents = self.intents.copy()
self.intents.clear()

chat_history = json.dumps(chat.chat_history, indent=4)

return RunResult(chat.summary, chat_history, transactions, EndReason.TERMINATE if is_goal_supported else EndReason.GOAL_NOT_SUPPORTED, float(chat.cost["usage_including_cached_inference"]["total_cost"]), float(chat.cost["usage_excluding_cached_inference"]["total_cost"]), self.info_messages)
return RunResult(chat.summary, chat_history, intents, EndReason.TERMINATE if is_goal_supported else EndReason.GOAL_NOT_SUPPORTED, float(chat.cost["usage_including_cached_inference"]["total_cost"]), float(chat.cost["usage_excluding_cached_inference"]["total_cost"]), self.info_messages)

def add_transactions(self, txs: list[models.Transaction]) -> None:
self.transactions.extend(txs)
self.wallet.on_transactions_prepared(txs)
def add_intents(self, intents: list[Intent]) -> None:
self.intents.extend(intents)
self.wallet.on_intents_prepared(intents)

def notify_user(self, message: str, color: Color | None = None) -> None:
if color:
Expand Down
4 changes: 2 additions & 2 deletions autotx/agents/ExampleAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def run(
# TODO: do something useful
autotx.notify_user(f"ExampleTool run: {amount} {receiver}")

# NOTE: you can add transactions to AutoTx's current bundle
# autotx.transactions.append(tx)
# NOTE: you can add intents to AutoTx's current bundle
# autotx.intents.append(tx)

return f"Something useful has been done with {amount} to {receiver}"

Expand Down
29 changes: 9 additions & 20 deletions autotx/agents/SendTokensAgent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from textwrap import dedent
from typing import Annotated, Any, Callable, cast
from typing import Annotated, Any, Callable

from web3 import Web3
from autotx import models
from autotx.AutoTx import AutoTx
from autotx.autotx_agent import AutoTxAgent
from autotx.autotx_tool import AutoTxTool
from autotx.intents import SendIntent
from autotx.token import Token
from autotx.utils.ethereum import (
build_transfer_erc20,
get_erc20_balance,
)
from autotx.utils.ethereum.build_transfer_native import build_transfer_native
from web3.constants import ADDRESS_ZERO
from autotx.utils.ethereum.constants import NATIVE_TOKEN_ADDRESS
from autotx.utils.ethereum.eth_address import ETHAddress
from autotx.utils.ethereum.get_native_balance import get_native_balance
Expand Down Expand Up @@ -89,26 +87,17 @@ def run(
receiver_addr = ETHAddress(receiver)
token_address = ETHAddress(autotx.network.tokens[token.lower()])

tx: TxParams

if token_address.hex == NATIVE_TOKEN_ADDRESS:
tx = build_transfer_native(autotx.web3, ETHAddress(ADDRESS_ZERO), receiver_addr, amount)
else:
tx = build_transfer_erc20(autotx.web3, token_address, receiver_addr, amount)

prepared_tx = models.SendTransaction.create(
token_symbol=token,
token_address=str(token_address),
intent = SendIntent.create(
token=Token(symbol=token, address=token_address.hex),
amount=amount,
receiver=str(receiver_addr),
params=cast(dict[str, Any], tx),
receiver=receiver_addr
)

autotx.add_transactions([prepared_tx])
autotx.add_intents([intent])

autotx.notify_user(f"Prepared transaction: {prepared_tx.summary}")
autotx.notify_user(f"Prepared transaction: {intent.summary}")

return prepared_tx.summary
return intent.summary

return run

Expand Down
40 changes: 26 additions & 14 deletions autotx/agents/SwapTokensAgent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from decimal import Decimal
from textwrap import dedent
from typing import Annotated, Callable
from autotx import models
from autotx.AutoTx import AutoTx
from autotx.autotx_agent import AutoTxAgent
from autotx.autotx_tool import AutoTxTool
from autotx.intents import BuyIntent, Intent, SellIntent
from autotx.token import Token
from autotx.utils.ethereum.eth_address import ETHAddress
from autotx.utils.ethereum.lifi.swap import SUPPORTED_NETWORKS_BY_LIFI, build_swap_transaction
from autotx.utils.ethereum.lifi.swap import SUPPORTED_NETWORKS_BY_LIFI, can_build_swap_transaction
from autotx.utils.ethereum.networks import NetworkInfo
from gnosis.eth import EthereumNetworkNotSupported as ChainIdNotSupported

Expand Down Expand Up @@ -80,7 +81,7 @@ def get_tokens_address(token_in: str, token_out: str, network_info: NetworkInfo)
class InvalidInput(Exception):
pass

def swap(autotx: AutoTx, token_to_sell: str, token_to_buy: str) -> list[models.Transaction]:
def swap(autotx: AutoTx, token_to_sell: str, token_to_buy: str) -> Intent:
sell_parts = token_to_sell.split(" ")
buy_parts = token_to_buy.split(" ")

Expand Down Expand Up @@ -117,7 +118,7 @@ def swap(autotx: AutoTx, token_to_sell: str, token_to_buy: str) -> list[models.T
token_in, token_out, autotx.network
)

swap_transactions = build_swap_transaction(
can_build_swap_transaction(
autotx.web3,
Decimal(exact_amount),
ETHAddress(token_in_address),
Expand All @@ -126,9 +127,21 @@ def swap(autotx: AutoTx, token_to_sell: str, token_to_buy: str) -> list[models.T
is_exact_input,
autotx.network.chain_id
)
autotx.add_transactions(swap_transactions)

return swap_transactions
# Create buy intent if amount of token to buy is provided else create sell intent
swap_intent: Intent = BuyIntent.create(
from_token=Token(symbol=token_symbol_to_sell, address=token_in_address),
to_token=Token(symbol=token_symbol_to_buy, address=token_out_address),
amount=float(exact_amount),
) if len(buy_parts) == 2 else SellIntent.create(
from_token=Token(symbol=token_symbol_to_sell, address=token_in_address),
to_token=Token(symbol=token_symbol_to_buy, address=token_out_address),
amount=float(exact_amount),
)

autotx.add_intents([swap_intent])

return swap_intent

class BulkSwapTool(AutoTxTool):
name: str = "prepare_bulk_swap_transactions"
Expand All @@ -152,36 +165,35 @@ def run(
],
) -> str:
swaps = tokens.split("\n")
all_txs = []
all_intents = []
all_errors: list[Exception] = []

for swap_str in swaps:
(token_to_sell, token_to_buy) = swap_str.strip().split(" to ")
try:
txs = swap(autotx, token_to_sell, token_to_buy)
all_txs.extend(txs)
all_intents.append(swap(autotx, token_to_sell, token_to_buy))
except InvalidInput as e:
all_errors.append(e)
except Exception as e:
all_errors.append(Exception(f"Error: {e} for swap \"{token_to_sell} to {token_to_buy}\""))


summary = "".join(
f"Prepared transaction: {swap_transaction.summary}\n"
for swap_transaction in all_txs
f"Prepared transaction: {intent.summary}\n"
for intent in all_intents
)

if all_errors:
summary += "\n".join(str(e) for e in all_errors)
if len(all_txs) > 0:
summary += f"\n{len(all_errors)} errors occurred. {len(all_txs)} transactions were prepared. There is no need to re-run the transactions that were prepared."
if len(all_intents) > 0:
summary += f"\n{len(all_errors)} errors occurred. {len(all_intents)} transactions were prepared. There is no need to re-run the transactions that were prepared."
else:
summary += f"\n{len(all_errors)} errors occurred."

total_summary = ("\n" + " " * 16).join(
[
f"{i + 1}. {tx.summary}"
for i, tx in enumerate(autotx.transactions)
for i, tx in enumerate(autotx.intents)
]
)

Expand Down
25 changes: 16 additions & 9 deletions autotx/db.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from datetime import datetime
import json
import os
from typing import Any
import uuid
from pydantic import BaseModel
from supabase import create_client
from supabase.client import Client
from supabase.lib.client_options import ClientOptions

from autotx import models
from autotx.transactions import Transaction

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
Expand Down Expand Up @@ -51,7 +53,7 @@ def start(self, prompt: str, address: str, chain_id: int, app_user_id: str) -> m
"created_at": str(created_at),
"updated_at": str(updated_at),
"messages": json.dumps([]),
"transactions": json.dumps([])
"intents": json.dumps([])
}
).execute()

Expand All @@ -65,7 +67,7 @@ def start(self, prompt: str, address: str, chain_id: int, app_user_id: str) -> m
running=True,
error=None,
messages=[],
transactions=[]
intents=[]
)

def stop(self, task_id: str) -> None:
Expand All @@ -81,7 +83,7 @@ def stop(self, task_id: str) -> None:
def update(self, task: models.Task) -> None:
client = get_db_client("public")

txs = [json.loads(tx.json()) for tx in task.transactions]
intents = [json.loads(intent.json()) for intent in task.intents]

client.table("tasks").update(
{
Expand All @@ -90,7 +92,7 @@ def update(self, task: models.Task) -> None:
"updated_at": str(datetime.utcnow()),
"messages": json.dumps(task.messages),
"error": task.error,
"transactions": json.dumps(txs)
"intents": json.dumps(intents)
}
).eq("id", task.id).eq("app_id", self.app_id).execute()

Expand Down Expand Up @@ -118,7 +120,7 @@ def get(self, task_id: str) -> models.Task | None:
running=task_data["running"],
error=task_data["error"],
messages=json.loads(task_data["messages"]),
transactions=json.loads(task_data["transactions"])
intents=json.loads(task_data["intents"])
)

def get_all(self) -> list[models.Task]:
Expand All @@ -140,7 +142,7 @@ def get_all(self) -> list[models.Task]:
running=task_data["running"],
error=task_data["error"],
messages=json.loads(task_data["messages"]),
transactions=json.loads(task_data["transactions"])
intents=json.loads(task_data["intents"])
)
)

Expand Down Expand Up @@ -222,9 +224,11 @@ def get_agent_private_key(app_id: str, user_id: str) -> str | None:

return str(result.data[0]["agent_private_key"])

def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: str, task_id: str) -> None:
def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: str, task_id: str, transactions: list[Transaction]) -> None:
client = get_db_client("public")

txs = [json.loads(tx.json()) for tx in transactions]

created_at = datetime.utcnow()
client.table("submitted_batches") \
.insert(
Expand All @@ -234,7 +238,8 @@ def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: s
"chain_id": chain_id,
"app_user_id": app_user_id,
"task_id": task_id,
"created_at": str(created_at)
"created_at": str(created_at),
"transactions": json.dumps(txs)
}
).execute()

Expand All @@ -246,6 +251,7 @@ class SubmittedBatch(BaseModel):
app_user_id: str
task_id: str
created_at: datetime
transactions: list[dict[str, Any]]

def get_submitted_batches(app_id: str, task_id: str) -> list[SubmittedBatch]:
client = get_db_client("public")
Expand All @@ -267,7 +273,8 @@ def get_submitted_batches(app_id: str, task_id: str) -> list[SubmittedBatch]:
chain_id=batch_data["chain_id"],
app_user_id=batch_data["app_user_id"],
task_id=batch_data["task_id"],
created_at=batch_data["created_at"]
created_at=batch_data["created_at"],
transactions=json.loads(batch_data["transactions"])
)
)

Expand Down
Loading

0 comments on commit d4663b0

Please sign in to comment.