Skip to content

Commit

Permalink
simplify OO + categories in db
Browse files Browse the repository at this point in the history
  • Loading branch information
daviidarr committed Apr 9, 2024
1 parent c1cbcb6 commit e2dc995
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 136 deletions.
4 changes: 2 additions & 2 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from plex.debank_api import DebankAPI
from utils.async_utils import safe_gather
from utils.db import PlexDB, SQLiteDB, RawDataDB, S3JsonRawDataDB
from utils.db import SQLiteDB, SQLiteDB, RawDataDB, S3JsonRawDataDB

if __name__ == '__main__':
if sys.argv[1] =='snapshot':
Expand All @@ -24,7 +24,7 @@
plex_db_params = copy.deepcopy(parameters['input_data']['plex_db'])
plex_db_params['remote_file'] = plex_db_params['remote_file'].replace('.db', f"_{parameters['profile']['debank_key']}.db")

plex_db: PlexDB = SQLiteDB(plex_db_params, secrets)
plex_db: SQLiteDB = SQLiteDB(plex_db_params, secrets)
raw_data_db: RawDataDB = RawDataDB.build_RawDataDB(parameters['input_data']['raw_data_db'], secrets)
api = DebankAPI(raw_data_db, plex_db, parameters)

Expand Down
75 changes: 0 additions & 75 deletions config/categories.yaml

This file was deleted.

1 change: 1 addition & 0 deletions config/params.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
profile:
debank_key: "0b9786c662bff596482c995ef9c654aa3663a120"
addresses:
- "0x7f8DA5FBD700a134842109c54ABA576D5c3712b8"
- "0xFaf2A8b5fa78cA2786cEf5F7e19f6942EC7cB531"
Expand Down
6 changes: 3 additions & 3 deletions plex/debank_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
import streamlit as st

from utils.async_utils import safe_gather
from utils.db import RawDataDB, PlexDB
from utils.db import RawDataDB, SQLiteDB


class DebankAPI:
endpoints = ["all_complex_protocol_list", "all_token_list", "all_nft_list"]
api_url = "https://pro-openapi.debank.com/v1"
def __init__(self, json_db: RawDataDB, plex_db: PlexDB, parameters: Dict[str, Any]):
def __init__(self, json_db: RawDataDB, plex_db: SQLiteDB, parameters: Dict[str, Any]):
self.parameters = parameters
self.json_db: RawDataDB = json_db
self.plex_db: PlexDB = plex_db
self.plex_db: SQLiteDB = plex_db

def get_credits(self) -> float:
response = requests.get(f'{self.api_url}/account/units',
Expand Down
12 changes: 5 additions & 7 deletions plex/plex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,16 @@


class PnlExplainer:
def __init__(self):
self.categories_path = os.path.join(os.getcwd(), 'config', 'categories.yaml')
with open(self.categories_path, 'r') as f:
self.categories = yaml.safe_load(f)
def __init__(self, categories: dict[str, str]):
self.categories = categories

def validate_categories(self, data) -> bool:
if missing_category := set(data['asset']) - set(self.categories.keys()):
st.warning(f"Categories need to be updated. Please categorize the following assets: {missing_category}")
return False
if missing_underlying := set(self.categories.values()) - set(data['asset']):
st.warning(f"I need underlying {missing_underlying} to have a position, maybe get some dust? Sorry...")
return False
# if missing_underlying := set(self.categories.values()) - set(data['asset']):
# st.warning(f"I need underlying {missing_underlying} to have a position, maybe get some dust? Sorry...")
# return False
return True

def explain(self, start_snapshot: pd.DataFrame, end_snapshot: pd.DataFrame) -> DataFrame:
Expand Down
10 changes: 5 additions & 5 deletions pnl_explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from plex.plex import PnlExplainer
from utils.async_utils import safe_gather
from utils.db import SQLiteDB, RawDataDB, PlexDB
from utils.db import SQLiteDB, RawDataDB
from plex.debank_api import DebankAPI

assert (sys.version_info >= (3, 10)), "Please use Python 3.10 or higher"
Expand All @@ -32,12 +32,12 @@
plex_db_params = copy.deepcopy(st.session_state.parameters['input_data']['plex_db'])
plex_db_params['remote_file'] = plex_db_params['remote_file'].replace('.db',
f"_{st.session_state.parameters['profile']['debank_key']}.db")
st.session_state.plex_db: PlexDB = SQLiteDB(plex_db_params, st.secrets)
st.session_state.plex_db: SQLiteDB = SQLiteDB(plex_db_params, st.secrets)
raw_data_db: RawDataDB = RawDataDB.build_RawDataDB(st.session_state.parameters['input_data']['raw_data_db'], st.secrets)
st.session_state.api = DebankAPI(json_db=raw_data_db,
plex_db=st.session_state.plex_db,
parameters=st.session_state.parameters)
st.session_state.pnl_explainer = PnlExplainer()
st.session_state.pnl_explainer = PnlExplainer(st.session_state.plex_db.query_categories())

addresses = st.session_state.parameters['profile']['addresses']
risk_tab, pnl_tab = st.tabs(
Expand Down Expand Up @@ -90,8 +90,8 @@
edited_categorization = st.data_editor(categorization, use_container_width=True)['underlying'].to_dict()
if st.form_submit_button("Override categorization"):
st.session_state.pnl_explainer.categories = edited_categorization
with open(st.session_state.pnl_explainer.categories_path, 'w') as f:
yaml.dump(edited_categorization, f)
st.session_state.plex_db.overwrite_categories(edited_categorization)
st.session_state.plex_db.upload_to_s3()
st.success("Categories updated (not exposure!)")

with pnl_tab:
Expand Down
65 changes: 24 additions & 41 deletions utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path

import boto3
import yaml
from botocore.exceptions import ClientError
import pandas as pd
import sqlite3
Expand Down Expand Up @@ -89,45 +90,7 @@ def all_timestamps(self, address: str, table_name: TableType) -> list[int]:
if file['Key'].endswith('.json') and address in file['Key']]


class PlexDB(ABC):
'''
Abstract class for PlexDB, where we put snapshots, one table per address
'''
@abstractmethod
def query_table_at(self, addresses: list[str], timestamp: int, table_name: TableType) -> pd.DataFrame:
raise NotImplementedError

@abstractmethod
def query_table_between(self, addresses: list[str], start_timestamp: int, end_timestamp: int, table_name: TableType) -> pd.DataFrame:
raise NotImplementedError

@abstractmethod
def insert_table(self, df: pd.DataFrame, table_name: TableType) -> None:
raise NotImplementedError

@abstractmethod
def all_timestamps(self, address: str, table_name: TableType) -> list[int]:
raise NotImplementedError

def last_updated(self, address: str, table_name: TableType) -> tuple[datetime, pd.DataFrame]:
if all_timestamps := self.all_timestamps(address, table_name):
timestamp = max(all_timestamps)
latest_table = self.query_table_at([address], timestamp, table_name)
return datetime.fromtimestamp(timestamp, tz=timezone.utc), latest_table
else:
return datetime(1970, 1, 1, tzinfo=timezone.utc), {}


class SQLiteDB(PlexDB):
plex_schema = {'chain': 'TEXT',
'protocol': 'TEXT',
'hold_mode': 'TEXT',
'type': 'TEXT',
'asset': 'TEXT',
'amount': 'REAL',
'price': 'REAL',
'value': 'REAL',
'timestamp': 'INTEGER'}
class SQLiteDB:
def __init__(self, config: dict, secrets: dict):
if 'bucket_name' in config and 'remote_file' in config:
# if bucket_name is in config, we are using s3 and download the file to ~
Expand Down Expand Up @@ -164,6 +127,14 @@ def __init__(self, config: dict, secrets: dict):
os.chmod(local_file, 0o777)
self.cursor = self.conn.cursor()

def last_updated(self, address: str, table_name: TableType) -> tuple[datetime, pd.DataFrame]:
if all_timestamps := self.all_timestamps(address, table_name):
timestamp = max(all_timestamps)
latest_table = self.query_table_at([address], timestamp, table_name)
return datetime.fromtimestamp(timestamp, tz=timezone.utc), latest_table
else:
return datetime(1970, 1, 1, tzinfo=timezone.utc), pd.DataFrame()

def upload_to_s3(self):
s3 = boto3.client('s3',
aws_access_key_id=self.secrets['AWS_ACCESS_KEY_ID'],
Expand All @@ -175,7 +146,6 @@ def insert_table(self, df: pd.DataFrame, table_name: TableType) -> None:
for address, data in df.groupby('address'):
table = f"{table_name}_{address}"
data.drop(columns='address').to_sql(table, self.conn, if_exists='append', index=False)
self.conn.commit()

def query_table_at(self, addresses: list[str], timestamp: int, table_name: TableType) -> pd.DataFrame:
return pd.concat([pd.read_sql_query(f'SELECT * FROM {table_name}_{address} WHERE timestamp = {timestamp}', self.conn)
Expand All @@ -192,4 +162,17 @@ def all_timestamps(self, address: str, table_name: TableType) -> list[int]:
self.cursor.execute(f'SELECT DISTINCT timestamp FROM {table_name}_{address}')
rows = self.cursor.fetchall()
return [row[0] for row in rows]


def query_categories(self) -> dict:
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table'", self.conn)
if 'categories' not in tables.values:
pd.DataFrame(columns=['asset', 'underlying']).to_sql('categories', self.conn, index=False)
return {}
return pd.read_sql_query('SELECT * FROM categories', self.conn).set_index('asset')['underlying'].to_dict()

def overwrite_categories(self, categories: dict) -> None:
# if True:
# with open(os.path.join(os.getcwd(), 'config', 'categories_SAVED.yaml'), 'r') as file:
# categories = yaml.safe_load(file)
pd.DataFrame({'asset':categories.keys(), 'underlying': categories.values()}).to_sql('categories', self.conn, index=False, if_exists='replace')
self.conn.commit()
6 changes: 3 additions & 3 deletions utils/streamlit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from plotly import express as px
from st_aggrid import AgGrid, GridOptionsBuilder

from utils.db import PlexDB
from utils.db import SQLiteDB


def load_parameters() -> dict:
Expand Down Expand Up @@ -49,7 +49,7 @@ def load_parameters() -> dict:
return st.session_state.parameters


def prompt_plex_interval(plex_db: PlexDB, addresses: list[str]) -> tuple[int, int]:
def prompt_plex_interval(plex_db: SQLiteDB, addresses: list[str]) -> tuple[int, int]:
date_col, time_col = st.columns(2)
now_datetime = datetime.now()
with time_col:
Expand Down Expand Up @@ -132,7 +132,7 @@ def download_button(df: pd.DataFrame, label: str, file_name: str, file_type='tex
mime=file_type
)

def download_db_button(db: PlexDB, label: str, file_name: str, file_type='application/x-sqlite3'):
def download_db_button(db: SQLiteDB, label: str, file_name: str, file_type='application/x-sqlite3'):
with open(db.data_location['local_file'], "rb") as file:
st.sidebar.download_button(
label=label,
Expand Down

0 comments on commit e2dc995

Please sign in to comment.