Skip to content

Commit

Permalink
Merge pull request #261 from polywrap/nerfzael/prepare-txs
Browse files Browse the repository at this point in the history
Prepare transactions before submitting
  • Loading branch information
nerfZael authored Jun 17, 2024
2 parents d4663b0 + f0c0e7a commit 17ad64d
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 47 deletions.
46 changes: 42 additions & 4 deletions autotx/db.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from datetime import datetime
import json
import os
from typing import Any
from typing import Any, cast
import uuid
from pydantic import BaseModel
from supabase import create_client
from supabase.client import Client
from supabase.lib.client_options import ClientOptions

from autotx import models
from autotx.transactions import Transaction
from autotx.transactions import Transaction, TransactionBase

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
Expand Down Expand Up @@ -224,13 +224,13 @@ def get_agent_private_key(app_id: str, user_id: str) -> str | None:

return str(result.data[0]["agent_private_key"])

def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: str, task_id: str, transactions: list[Transaction]) -> None:
def save_transactions(app_id: str, address: str, chain_id: int, app_user_id: str, task_id: str, transactions: list[Transaction]) -> str:
client = get_db_client("public")

txs = [json.loads(tx.json()) for tx in transactions]

created_at = datetime.utcnow()
client.table("submitted_batches") \
result = client.table("submitted_batches") \
.insert(
{
"app_id": app_id,
Expand All @@ -243,6 +243,42 @@ def submit_transactions(app_id: str, address: str, chain_id: int, app_user_id: s
}
).execute()

return cast(str, result.data[0]["id"])

def get_transactions(app_id: str, app_user_id: str, task_id: str, address: str, chain_id: int, submitted_batch_id: str) -> tuple[list[TransactionBase], str] | None:
client = get_db_client("public")

result = client.table("submitted_batches") \
.select("transactions, task_id") \
.eq("app_id", app_id) \
.eq("app_user_id", app_user_id) \
.eq("address", address) \
.eq("chain_id", chain_id) \
.eq("task_id", task_id) \
.eq("id", submitted_batch_id) \
.execute()

if len(result.data) == 0:
return None

return (
[TransactionBase(**tx) for tx in json.loads(result.data[0]["transactions"])],
result.data[0]["task_id"]
)

def submit_transactions(app_id: str, app_user_id: str, submitted_batch_id: str) -> None:
client = get_db_client("public")

client.table("submitted_batches") \
.update(
{
"submitted_on": str(datetime.utcnow())
}
).eq("app_id", app_id) \
.eq("app_user_id", app_user_id) \
.eq("id", submitted_batch_id) \
.execute()

class SubmittedBatch(BaseModel):
id: str
app_id: str
Expand All @@ -251,6 +287,7 @@ class SubmittedBatch(BaseModel):
app_user_id: str
task_id: str
created_at: datetime
submitted_on: datetime | None
transactions: list[dict[str, Any]]

def get_submitted_batches(app_id: str, task_id: str) -> list[SubmittedBatch]:
Expand All @@ -274,6 +311,7 @@ def get_submitted_batches(app_id: str, task_id: str) -> list[SubmittedBatch]:
app_user_id=batch_data["app_user_id"],
task_id=batch_data["task_id"],
created_at=batch_data["created_at"],
submitted_on=batch_data["submitted_on"],
transactions=json.loads(batch_data["transactions"])
)
)
Expand Down
112 changes: 78 additions & 34 deletions autotx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi import APIRouter, FastAPI, BackgroundTasks, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import traceback

from autotx import models, setup
Expand Down Expand Up @@ -60,6 +61,43 @@ def authorize(authorization: str | None) -> models.App:

return app

def load_config_for_user(app_id: str, user_id: str, address: str, chain_id: int) -> AppConfig:
agent_private_key = db.get_agent_private_key(app_id, user_id)

if not agent_private_key:
raise HTTPException(status_code=400, detail="User not found")

agent = Account.from_key(agent_private_key)

app_config = AppConfig.load(smart_account_addr=address, subsidized_chain_id=chain_id, agent=agent)

return app_config

def authorize_app_and_user(authorization: str | None, user_id: str) -> tuple[models.App, models.AppUser]:
app = authorize(authorization)
app_user = db.get_app_user(app.id, user_id)

if not app_user:
raise HTTPException(status_code=400, detail="User not found")

return (app, app_user)

def build_transactions(app_id: str, user_id: str, chain_id: int, address: str, task: models.Task) -> List[Transaction]:
if task.running:
raise HTTPException(status_code=400, detail="Task is still running")

app_config = load_config_for_user(app_id, user_id, address, chain_id)

if task.intents is None or len(task.intents) == 0:
return []

transactions: list[Transaction] = []

for intent in task.intents:
transactions.extend(intent.build_transactions(app_config.web3, app_config.network_info, app_config.manager.address))

return transactions

@app_router.post("/api/v1/tasks", response_model=models.Task)
async def create_task(task: models.TaskCreate, background_tasks: BackgroundTasks, authorization: Annotated[str | None, Header()] = None) -> models.Task:
app = authorize(authorization)
Expand Down Expand Up @@ -172,46 +210,39 @@ def get_intents(task_id: str, authorization: Annotated[str | None, Header()] = N
task = get_task_or_404(task_id, tasks)
return task.intents

def authorize_app_and_user(authorization: str | None, user_id: str) -> tuple[models.App, models.AppUser]:
app = authorize(authorization)
app_user = db.get_app_user(app.id, user_id)

if not app_user:
raise HTTPException(status_code=400, detail="User not found")

return (app, app_user)

def build_transactions(app_id: str, user_id: str, chain_id: int, address: str, task: models.Task) -> tuple[List[Transaction], AppConfig]:
if task.running:
raise HTTPException(status_code=400, detail="Task is still running")

agent_private_key = db.get_agent_private_key(app_id, user_id)

if not agent_private_key:
raise HTTPException(status_code=400, detail="User not found")

agent = Account.from_key(agent_private_key)
@app_router.get("/api/v1/tasks/{task_id}/transactions", response_model=List[Transaction])
def get_transactions(
task_id: str,
address: str,
chain_id: int,
user_id: str,
authorization: Annotated[str | None, Header()] = None
) -> List[Transaction]:
(app, app_user) = authorize_app_and_user(authorization, user_id)

app_config = AppConfig.load(smart_account_addr=address, subsidized_chain_id=chain_id, agent=agent)
tasks = db.TasksRepository(app.id)

task = get_task_or_404(task_id, tasks)

if task.intents is None or len(task.intents) == 0:
return ([], app_config)
if task.chain_id != chain_id:
raise HTTPException(status_code=400, detail="Chain ID does not match task")

transactions: list[Transaction] = []
transactions = build_transactions(app.id, user_id, chain_id, address, task)

for intent in task.intents:
transactions.extend(intent.build_transactions(app_config.web3, app_config.network_info, app_config.manager.address))
return transactions

return (transactions, app_config)
class PreparedTransactionsDto(BaseModel):
batch_id: str
transactions: List[Transaction]

@app_router.get("/api/v1/tasks/{task_id}/transactions", response_model=List[Transaction])
def get_transactions(
@app_router.post("/api/v1/tasks/{task_id}/transactions/prepare", response_model=PreparedTransactionsDto)
def prepare_transactions(
task_id: str,
address: str,
chain_id: int,
user_id: str,
authorization: Annotated[str | None, Header()] = None
) -> List[Transaction]:
) -> PreparedTransactionsDto:
(app, app_user) = authorize_app_and_user(authorization, user_id)

tasks = db.TasksRepository(app.id)
Expand All @@ -220,17 +251,23 @@ def get_transactions(

if task.chain_id != chain_id:
raise HTTPException(status_code=400, detail="Chain ID does not match task")

transactions = build_transactions(app.id, app_user.user_id, chain_id, address, task)

if len(transactions) == 0:
raise HTTPException(status_code=400, detail="No transactions to send")

(transactions, _) = build_transactions(app.id, app_user.user_id, chain_id, address, task)
submitted_batch_id = db.save_transactions(app.id, address, chain_id, app_user.id, task_id, transactions)

return transactions
return PreparedTransactionsDto(batch_id=submitted_batch_id, transactions=transactions)

@app_router.post("/api/v1/tasks/{task_id}/transactions")
def send_transactions(
task_id: str,
address: str,
chain_id: int,
user_id: str,
batch_id: str,
authorization: Annotated[str | None, Header()] = None
) -> str:
(app, app_user) = authorize_app_and_user(authorization, user_id)
Expand All @@ -242,18 +279,25 @@ def send_transactions(
if task.chain_id != chain_id:
raise HTTPException(status_code=400, detail="Chain ID does not match task")

(transactions, app_config) = build_transactions(app.id, app_user.user_id, chain_id, address, task)
batch = db.get_transactions(app.id, app_user.id, task_id, address, chain_id, batch_id)

if batch is None:
raise HTTPException(status_code=400, detail="Batch not found")

(transactions, task_id) = batch

if len(transactions) == 0:
raise HTTPException(status_code=400, detail="No transactions to send")

global autotx_params
if autotx_params.is_dev:
print("Dev mode: skipping transaction submission")
db.submit_transactions(app.id, address, chain_id, app_user.id, task_id, transactions)
db.submit_transactions(app.id, app_user.id, batch_id)
return f"https://app.safe.global/transactions/queue?safe={CHAIN_ID_TO_SHORT_NAME[str(chain_id)]}:{address}"

try:
app_config = load_config_for_user(app.id, user_id, address, chain_id)

app_config.manager.send_multisend_tx_batch(
transactions,
require_approval=False,
Expand All @@ -264,7 +308,7 @@ def send_transactions(
else:
raise e

db.submit_transactions(app.id, address, chain_id, app_user.id, task_id, transactions)
db.submit_transactions(app.id, app_user.id, batch_id)

return f"https://app.safe.global/transactions/queue?safe={CHAIN_ID_TO_SHORT_NAME[str(chain_id)]}:{address}"

Expand Down
Loading

0 comments on commit 17ad64d

Please sign in to comment.