From 08cd799a1549c8ea397f01d38e2b4d9614a70b73 Mon Sep 17 00:00:00 2001 From: nerfZael Date: Mon, 17 Jun 2024 21:26:51 +0200 Subject: [PATCH 1/2] preparing transactions before submitting --- autotx/db.py | 44 ++++++- autotx/server.py | 112 ++++++++++++------ autotx/tests/api/test_send_transactions.py | 86 +++++++++++++- .../20240617165316_submitting-txs.sql | 3 + 4 files changed, 204 insertions(+), 41 deletions(-) create mode 100644 supabase/migrations/20240617165316_submitting-txs.sql diff --git a/autotx/db.py b/autotx/db.py index 8dd73f6..c4a2738 100644 --- a/autotx/db.py +++ b/autotx/db.py @@ -9,7 +9,7 @@ from supabase.lib.client_options import ClientOptions from autotx import models -from autotx.transactions import Transaction +from autotx.transactions import Transaction, TransactionBase SUPABASE_URL = os.getenv("SUPABASE_URL") SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY") @@ -224,13 +224,13 @@ def get_agent_private_key(app_id: str, user_id: str) -> str | None: return str(result.data[0]["agent_private_key"]) -def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: str, task_id: str, transactions: list[Transaction]) -> None: +def save_transactions(app_id: str, address: str, chain_id: int, app_user_id: str, task_id: str, transactions: list[Transaction]) -> str: client = get_db_client("public") txs = [json.loads(tx.json()) for tx in transactions] created_at = datetime.utcnow() - client.table("submitted_batches") \ + result = client.table("submitted_batches") \ .insert( { "app_id": app_id, @@ -243,6 +243,42 @@ def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: s } ).execute() + return result.data[0]["id"] + +def get_transactions(app_id: str, app_user_id: str, task_id: str, address: str, chain_id: str, submitted_batch_id: str) -> tuple[list[TransactionBase], str] | None: + client = get_db_client("public") + + result = client.table("submitted_batches") \ + .select("transactions, task_id") \ + .eq("app_id", app_id) \ + .eq("app_user_id", app_user_id) \ + .eq("address", address) \ + .eq("chain_id", chain_id) \ + .eq("task_id", task_id) \ + .eq("id", submitted_batch_id) \ + .execute() + + if len(result.data) == 0: + return None + + return ( + [TransactionBase(**tx) for tx in json.loads(result.data[0]["transactions"])], + result.data[0]["task_id"] + ) + +def submit_transactions(app_id: str, app_user_id: str, submitted_batch_id: str) -> None: + client = get_db_client("public") + + client.table("submitted_batches") \ + .update( + { + "submitted_on": str(datetime.utcnow()) + } + ).eq("app_id", app_id) \ + .eq("app_user_id", app_user_id) \ + .eq("id", submitted_batch_id) \ + .execute() + class SubmittedBatch(BaseModel): id: str app_id: str @@ -251,6 +287,7 @@ class SubmittedBatch(BaseModel): app_user_id: str task_id: str created_at: datetime + submitted_on: datetime | None transactions: list[dict[str, Any]] def get_submitted_batches(app_id: str, task_id: str) -> list[SubmittedBatch]: @@ -274,6 +311,7 @@ def get_submitted_batches(app_id: str, task_id: str) -> list[SubmittedBatch]: app_user_id=batch_data["app_user_id"], task_id=batch_data["task_id"], created_at=batch_data["created_at"], + submitted_on=batch_data["submitted_on"], transactions=json.loads(batch_data["transactions"]) ) ) diff --git a/autotx/server.py b/autotx/server.py index c23f06f..51485eb 100644 --- a/autotx/server.py +++ b/autotx/server.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, FastAPI, BackgroundTasks, HTTPException, Header from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from pydantic import BaseModel import traceback from autotx import models, setup @@ -60,6 +61,43 @@ def authorize(authorization: str | None) -> models.App: return app +def load_config_for_user(app_id: str, user_id: str, address: str, chain_id: int) -> AppConfig: + agent_private_key = db.get_agent_private_key(app_id, user_id) + + if not agent_private_key: + raise HTTPException(status_code=400, detail="User not found") + + agent = Account.from_key(agent_private_key) + + app_config = AppConfig.load(smart_account_addr=address, subsidized_chain_id=chain_id, agent=agent) + + return app_config + +def authorize_app_and_user(authorization: str | None, user_id: str) -> tuple[models.App, models.AppUser]: + app = authorize(authorization) + app_user = db.get_app_user(app.id, user_id) + + if not app_user: + raise HTTPException(status_code=400, detail="User not found") + + return (app, app_user) + +def build_transactions(app_id: str, user_id: str, chain_id: int, address: str, task: models.Task) -> List[Transaction]: + if task.running: + raise HTTPException(status_code=400, detail="Task is still running") + + app_config = load_config_for_user(app_id, user_id, address, chain_id) + + if task.intents is None or len(task.intents) == 0: + return ([], app_config) + + transactions: list[Transaction] = [] + + for intent in task.intents: + transactions.extend(intent.build_transactions(app_config.web3, app_config.network_info, app_config.manager.address)) + + return transactions + @app_router.post("/api/v1/tasks", response_model=models.Task) async def create_task(task: models.TaskCreate, background_tasks: BackgroundTasks, authorization: Annotated[str | None, Header()] = None) -> models.Task: app = authorize(authorization) @@ -172,46 +210,39 @@ def get_intents(task_id: str, authorization: Annotated[str | None, Header()] = N task = get_task_or_404(task_id, tasks) return task.intents -def authorize_app_and_user(authorization: str | None, user_id: str) -> tuple[models.App, models.AppUser]: - app = authorize(authorization) - app_user = db.get_app_user(app.id, user_id) - - if not app_user: - raise HTTPException(status_code=400, detail="User not found") - - return (app, app_user) - -def build_transactions(app_id: str, user_id: str, chain_id: int, address: str, task: models.Task) -> tuple[List[Transaction], AppConfig]: - if task.running: - raise HTTPException(status_code=400, detail="Task is still running") - - agent_private_key = db.get_agent_private_key(app_id, user_id) - - if not agent_private_key: - raise HTTPException(status_code=400, detail="User not found") - - agent = Account.from_key(agent_private_key) +@app_router.get("/api/v1/tasks/{task_id}/transactions", response_model=List[Transaction]) +def get_transactions( + task_id: str, + address: str, + chain_id: int, + user_id: str, + authorization: Annotated[str | None, Header()] = None +) -> List[Transaction]: + (app, app_user) = authorize_app_and_user(authorization, user_id) - app_config = AppConfig.load(smart_account_addr=address, subsidized_chain_id=chain_id, agent=agent) + tasks = db.TasksRepository(app.id) + + task = get_task_or_404(task_id, tasks) - if task.intents is None or len(task.intents) == 0: - return ([], app_config) + if task.chain_id != chain_id: + raise HTTPException(status_code=400, detail="Chain ID does not match task") - transactions: list[Transaction] = [] + transactions = build_transactions(app.id, user_id, chain_id, address, task) - for intent in task.intents: - transactions.extend(intent.build_transactions(app_config.web3, app_config.network_info, app_config.manager.address)) + return transactions - return (transactions, app_config) +class PreparedTransactionsDto(BaseModel): + batch_id: str + transactions: List[Transaction] -@app_router.get("/api/v1/tasks/{task_id}/transactions", response_model=List[Transaction]) -def get_transactions( +@app_router.post("/api/v1/tasks/{task_id}/transactions/prepare", response_model=PreparedTransactionsDto) +def prepare_transactions( task_id: str, address: str, chain_id: int, user_id: str, authorization: Annotated[str | None, Header()] = None -) -> List[Transaction]: +) -> str: (app, app_user) = authorize_app_and_user(authorization, user_id) tasks = db.TasksRepository(app.id) @@ -220,10 +251,15 @@ def get_transactions( if task.chain_id != chain_id: raise HTTPException(status_code=400, detail="Chain ID does not match task") + + transactions = build_transactions(app.id, app_user.user_id, chain_id, address, task) + + if len(transactions) == 0: + raise HTTPException(status_code=400, detail="No transactions to send") - (transactions, _) = build_transactions(app.id, app_user.user_id, chain_id, address, task) + submitted_batch_id = db.save_transactions(app.id, address, chain_id, app_user.id, task_id, transactions) - return transactions + return PreparedTransactionsDto(batch_id=submitted_batch_id, transactions=transactions) @app_router.post("/api/v1/tasks/{task_id}/transactions") def send_transactions( @@ -231,6 +267,7 @@ def send_transactions( address: str, chain_id: int, user_id: str, + batch_id: str, authorization: Annotated[str | None, Header()] = None ) -> str: (app, app_user) = authorize_app_and_user(authorization, user_id) @@ -242,7 +279,12 @@ def send_transactions( if task.chain_id != chain_id: raise HTTPException(status_code=400, detail="Chain ID does not match task") - (transactions, app_config) = build_transactions(app.id, app_user.user_id, chain_id, address, task) + batch = db.get_transactions(app.id, app_user.id, task_id, address, chain_id, batch_id) + + if batch is None: + raise HTTPException(status_code=400, detail="Batch not found") + + (transactions, task_id) = batch if len(transactions) == 0: raise HTTPException(status_code=400, detail="No transactions to send") @@ -250,10 +292,12 @@ def send_transactions( global autotx_params if autotx_params.is_dev: print("Dev mode: skipping transaction submission") - db.submit_transactions(app.id, address, chain_id, app_user.id, task_id, transactions) + db.submit_transactions(app.id, app_user.id, batch_id) return f"https://app.safe.global/transactions/queue?safe={CHAIN_ID_TO_SHORT_NAME[str(chain_id)]}:{address}" try: + app_config = load_config_for_user(app.id, user_id, address, chain_id) + app_config.manager.send_multisend_tx_batch( transactions, require_approval=False, @@ -264,7 +308,7 @@ def send_transactions( else: raise e - db.submit_transactions(app.id, address, chain_id, app_user.id, task_id, transactions) + db.submit_transactions(app.id, app_user.id, batch_id) return f"https://app.safe.global/transactions/queue?safe={CHAIN_ID_TO_SHORT_NAME[str(chain_id)]}:{address}" diff --git a/autotx/tests/api/test_send_transactions.py b/autotx/tests/api/test_send_transactions.py index b6a8e35..19fbbbf 100644 --- a/autotx/tests/api/test_send_transactions.py +++ b/autotx/tests/api/test_send_transactions.py @@ -23,6 +23,17 @@ def test_get_transactions_auth(): }) assert response.status_code == 401 +def test_prepare_transactions_auth(): + + user_id = uuid.uuid4().hex + + response = client.post("/api/v1/tasks/123/transactions/prepare", params={ + "user_id": user_id, + "address": "0x123", + "chain_id": 1, + }) + assert response.status_code == 401 + def test_send_transactions_auth(): user_id = uuid.uuid4().hex @@ -31,6 +42,7 @@ def test_send_transactions_auth(): "user_id": user_id, "address": "0x123", "chain_id": 1, + "batch_id": "123" }) assert response.status_code == 401 @@ -122,7 +134,7 @@ def test_send_transactions(): task_id = data["id"] - response = client.post(f"/api/v1/tasks/{task_id}/transactions", params={ + response = client.post(f"/api/v1/tasks/{task_id}/transactions/prepare", params={ "user_id": user_id, "address": smart_wallet_address, "chain_id": 2, @@ -131,7 +143,17 @@ def test_send_transactions(): }) assert response.status_code == 400 - response = client.post(f"/api/v1/tasks/{task_id}/transactions", params={ + response = client.post(f"/api/v1/tasks/{task_id}/transactions/prepare", params={ + "user_id": user_id, + "address": smart_wallet_address, + "chain_id": 1, + }, headers={ + "Authorization": f"Bearer 1234" + }) + assert response.status_code == 200 + batch1 = response.json() + + response = client.post(f"/api/v1/tasks/{task_id}/transactions/prepare", params={ "user_id": user_id, "address": smart_wallet_address, "chain_id": 1, @@ -139,14 +161,70 @@ def test_send_transactions(): "Authorization": f"Bearer 1234" }) assert response.status_code == 200 + batch2 = response.json() app = db.get_app_by_api_key("1234") batches = db.get_submitted_batches(app.id, task_id) - - assert len(batches) == 1 + assert len(batches) == 2 + assert batches[0].app_id == app.id assert batches[0].address == smart_wallet_address assert batches[0].chain_id == 1 assert batches[0].task_id == task_id assert batches[0].created_at is not None + assert batches[0].submitted_on is None + + assert batches[1].app_id == app.id + assert batches[1].address == smart_wallet_address + assert batches[1].chain_id == 1 + assert batches[1].task_id == task_id + assert batches[1].created_at is not None + assert batches[1].submitted_on is None + + assert batch1["batch_id"] == batches[0].id + assert len(batch1["transactions"]) == 1 + + assert batch2["batch_id"] == batches[1].id + assert len(batch2["transactions"]) == 1 + + response = client.post(f"/api/v1/tasks/{task_id}/transactions", params={ + "user_id": user_id, + "address": smart_wallet_address, + "chain_id": 2, + "batch_id": batch1["batch_id"] + }, headers={ + "Authorization": f"Bearer 1234" + }) + assert response.status_code == 400 + + response = client.post(f"/api/v1/tasks/{task_id}/transactions", params={ + "user_id": user_id, + "address": smart_wallet_address, + "chain_id": 1, + "batch_id": batch1["batch_id"] + }, headers={ + "Authorization": f"Bearer 1234" + }) + assert response.status_code == 200 + + batches = db.get_submitted_batches(app.id, task_id) + batches = sorted(batches, key=lambda x: x.created_at) + assert len(batches) == 2 + assert batches[0].submitted_on is not None + assert batches[1].submitted_on is None + + response = client.post(f"/api/v1/tasks/{task_id}/transactions", params={ + "user_id": user_id, + "address": smart_wallet_address, + "chain_id": 1, + "batch_id": batch2["batch_id"] + }, headers={ + "Authorization": f"Bearer 1234" + }) + assert response.status_code == 200 + + batches = db.get_submitted_batches(app.id, task_id) + assert len(batches) == 2 + assert batches[0].submitted_on is not None + assert batches[1].submitted_on is not None diff --git a/supabase/migrations/20240617165316_submitting-txs.sql b/supabase/migrations/20240617165316_submitting-txs.sql new file mode 100644 index 0000000..f56fed8 --- /dev/null +++ b/supabase/migrations/20240617165316_submitting-txs.sql @@ -0,0 +1,3 @@ +alter table "public"."submitted_batches" add column "submitted_on" timestamp with time zone; + + From f0c0e7acaa0da73e828bdcd0910d508aa433b4a5 Mon Sep 17 00:00:00 2001 From: nerfZael Date: Mon, 17 Jun 2024 21:33:03 +0200 Subject: [PATCH 2/2] type fixes --- autotx/db.py | 6 +++--- autotx/server.py | 4 ++-- autotx/utils/ethereum/SafeManager.py | 7 +++---- autotx/wallets/safe_smart_wallet.py | 3 ++- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/autotx/db.py b/autotx/db.py index c4a2738..2ea56ad 100644 --- a/autotx/db.py +++ b/autotx/db.py @@ -1,7 +1,7 @@ from datetime import datetime import json import os -from typing import Any +from typing import Any, cast import uuid from pydantic import BaseModel from supabase import create_client @@ -243,9 +243,9 @@ def save_transactions(app_id: str, address: str, chain_id: int, app_user_id: str } ).execute() - return result.data[0]["id"] + return cast(str, result.data[0]["id"]) -def get_transactions(app_id: str, app_user_id: str, task_id: str, address: str, chain_id: str, submitted_batch_id: str) -> tuple[list[TransactionBase], str] | None: +def get_transactions(app_id: str, app_user_id: str, task_id: str, address: str, chain_id: int, submitted_batch_id: str) -> tuple[list[TransactionBase], str] | None: client = get_db_client("public") result = client.table("submitted_batches") \ diff --git a/autotx/server.py b/autotx/server.py index 51485eb..c0f10ab 100644 --- a/autotx/server.py +++ b/autotx/server.py @@ -89,7 +89,7 @@ def build_transactions(app_id: str, user_id: str, chain_id: int, address: str, t app_config = load_config_for_user(app_id, user_id, address, chain_id) if task.intents is None or len(task.intents) == 0: - return ([], app_config) + return [] transactions: list[Transaction] = [] @@ -242,7 +242,7 @@ def prepare_transactions( chain_id: int, user_id: str, authorization: Annotated[str | None, Header()] = None -) -> str: +) -> PreparedTransactionsDto: (app, app_user) = authorize_app_and_user(authorization, user_id) tasks = db.TasksRepository(app.id) diff --git a/autotx/utils/ethereum/SafeManager.py b/autotx/utils/ethereum/SafeManager.py index 77177e2..9072532 100644 --- a/autotx/utils/ethereum/SafeManager.py +++ b/autotx/utils/ethereum/SafeManager.py @@ -14,8 +14,7 @@ from gnosis.safe.api import TransactionServiceApi from eth_account.signers.local import LocalAccount -from autotx import models -from autotx.transactions import Transaction +from autotx.transactions import TransactionBase from autotx.utils.ethereum.get_native_balance import get_native_balance from autotx.utils.ethereum.cached_safe_address import get_cached_safe_address, save_cached_safe_address from autotx.utils.ethereum.eth_address import ETHAddress @@ -242,7 +241,7 @@ def send_multisend_tx(self, txs: list[TxParams | dict[str, Any]], safe_nonce: Op hash = self.execute_multisend_tx(txs, safe_nonce) return hash.hex() - def send_tx_batch(self, txs: list[Transaction], require_approval: bool, safe_nonce: Optional[int] = None) -> bool | str: # True if sent, False if declined, str if feedback + def send_tx_batch(self, txs: list[TransactionBase], require_approval: bool, safe_nonce: Optional[int] = None) -> bool | str: # True if sent, False if declined, str if feedback print("=" * 50) if not txs: @@ -316,7 +315,7 @@ def send_tx_batch(self, txs: list[Transaction], require_approval: bool, safe_non return True - def send_multisend_tx_batch(self, txs: list[Transaction], require_approval: bool, safe_nonce: Optional[int] = None) -> bool | str: # True if sent, False if declined, str if feedback + def send_multisend_tx_batch(self, txs: list[TransactionBase], require_approval: bool, safe_nonce: Optional[int] = None) -> bool | str: # True if sent, False if declined, str if feedback print("=" * 50) if not txs: diff --git a/autotx/wallets/safe_smart_wallet.py b/autotx/wallets/safe_smart_wallet.py index 8b7b147..c3db603 100644 --- a/autotx/wallets/safe_smart_wallet.py +++ b/autotx/wallets/safe_smart_wallet.py @@ -1,4 +1,5 @@ from autotx.intents import Intent +from autotx.transactions import TransactionBase from autotx.utils.ethereum import SafeManager from autotx.wallets.smart_wallet import SmartWallet @@ -17,7 +18,7 @@ def on_intents_prepared(self, intents: list[Intent]) -> None: pass def on_intents_ready(self, intents: list[Intent]) -> bool | str: - transactions = [] + transactions: list[TransactionBase] = [] for intent in intents: transactions.extend(intent.build_transactions(self.manager.web3, self.manager.network, self.manager.address))