Skip to content

Commit

Permalink
feat(api): adding aggregate model
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianHymer committed Jul 26, 2024
1 parent c237c6a commit 034d6d0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
55 changes: 42 additions & 13 deletions api/passport/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import Dict
from typing import Dict, List

import aiohttp
import api_logging as logging
Expand All @@ -12,6 +12,7 @@
from ninja_extra.exceptions import APIException
from registry.api.utils import aapi_key, check_rate_limit, is_valid_address
from registry.exceptions import InvalidAddressException
from scorer.settings.model_config import MODEL_AGGREGATION_KEYS

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,17 +88,11 @@ async def fetch(session, url, data):
return await response.json()


async def fetch_all(urls, address):
async def fetch_all(urls, payload):
async with aiohttp.ClientSession() as session:
tasks = []
for url in urls:
task = asyncio.ensure_future(
fetch(
session,
url,
{"address": address},
)
)
task = asyncio.ensure_future(fetch(session, url, payload))
tasks.append(task)
responses = await asyncio.gather(*tasks)
return responses
Expand Down Expand Up @@ -130,13 +125,17 @@ async def handle_get_analysis(
detail=f"Invalid model name(s): {', '.join(bad_models)}. Must be one of {', '.join(settings.MODEL_ENDPOINTS.keys())}"
)

urls = [settings.MODEL_ENDPOINTS[model] for model in models]

# The cache historically uses checksummed addresses, need to do this for consistency
checksummed_address = to_checksum_address(address)

try:
responses = await fetch_all(urls, checksummed_address)
# TODO How to handle this when multiple models allowed at once?
# Maybe prefetch all requested non-aggregate and pass them to the aggregate
# model which will skip checking those again?
if settings.AGGREGATE_MODEL_NAME in models:
responses = await get_aggregate_model_response(checksummed_address)
else:
responses = await get_model_responses(models, checksummed_address)

ret = PassportAnalysisResponse(
address=address,
Expand All @@ -148,7 +147,37 @@ async def handle_get_analysis(
)

return ret

except Exception:
log.error("Error retrieving Passport analysis", exc_info=True)
raise PassportAnalysisError()


async def get_aggregate_model_response(checksummed_address: str):
models = [model for model in MODEL_AGGREGATION_KEYS]

model_responses = await get_model_responses(models, checksummed_address)

payload = {
"address": checksummed_address,
"data": {},
}

for model, response in zip(models, model_responses):
data = response.get("data", {})
score = data.get("human_probability", 0)
num_transactions = data.get("n_transactions", 0)
model_key = MODEL_AGGREGATION_KEYS[model]

payload["data"][f"score_{model_key}"] = score
payload["data"][f"txs_{model_key}"] = num_transactions

url = settings.MODEL_ENDPOINTS[settings.AGGREGATE_MODEL_NAME]

return await fetch_all([url], payload)


async def get_model_responses(models: List[str], checksummed_address: str):
urls = [settings.MODEL_ENDPOINTS[model] for model in models]

payload = {"address": checksummed_address}
return await fetch_all(urls, payload)
17 changes: 16 additions & 1 deletion api/scorer/settings/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@
OPTIMISM_MODEL_ENDPOINT = env(
"OPTIMISM_MODEL_ENDPOINT", default="http://localhost:80/zksync"
)
AGGREGATE_MODEL_ENDPOINT = env(
"AGGREGATE_MODEL_ENDPOINT", default="http://localhost:80/aggregate"
)

AGGREGATE_MODEL_NAME = "aggregate"


MODEL_AGGREGATION_KEYS = {
"zksync": "zk",
"polygon": "polygon",
"ethereum_activity": "eth",
"arbitrum": "arb",
"optimism": "op",
}

MODEL_ENDPOINTS = {
"ethereum_activity": ETHEREUM_MODEL_ENDPOINT,
Expand All @@ -24,6 +38,7 @@
"polygon": POLYGON_MODEL_ENDPOINT,
"arbitrum": ARBITRUM_MODEL_ENDPOINT,
"optimism": OPTIMISM_MODEL_ENDPOINT,
AGGREGATE_MODEL_NAME: AGGREGATE_MODEL_ENDPOINT,
}

MODEL_ENDPOINTS_DEFAULT = "ethereum_activity"
MODEL_ENDPOINTS_DEFAULT = "aggregate"

0 comments on commit 034d6d0

Please sign in to comment.