Skip to content

Commit

Permalink
Merge pull request #58 from jyaacoub/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
jyaacoub authored Nov 14, 2023
2 parents a0a2284 + 2d71560 commit e6077d0
Show file tree
Hide file tree
Showing 46 changed files with 101,072 additions and 226 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,6 @@ lib/mgltools_x86_64Linux2_1.5.7p1.tar.gz
log_test/
slurm_tests/
slurm_out_DDP/
slurm_out*/
/*.sh
results/model_checkpoints/ours/*.model*
44 changes: 23 additions & 21 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
numpy==1.23.5
pandas==1.5.3
tqdm==4.65.0
rdkit==2023.3.1
scipy==1.10.1
numpy
pandas
tqdm
rdkit
scipy

# for generating figures:
matplotlib==3.7.1
seaborn==0.11.2
statannotations==0.6.0
matplotlib
seaborn
statannotations

lifelines==0.27.7 # used for concordance index calc
#biopython # used for cmap
lifelines

# model building
torch==2.0.1
torch-geometric==2.3.1
transformers==4.31.0 # huggingface needed for esm
torch
torch-geometric
transformers

# optional:
torchsummary==1.5.1
tabulate==0.9.0 # for torch_geometric.nn.summary
ipykernel==6.23.1
plotly==5.14.1
requests==2.31.0
#ray[tune]
torchsummary
tabulate
ipykernel
plotly
requests
ray[tune]

submitit==1.4.5
ProDy==2.4.1
submitit
ProDy

# for chemgpt
selfies
32 changes: 32 additions & 0 deletions docs/requirements_versions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
numpy==1.23.5
pandas==1.5.3
tqdm==4.65.0
rdkit==2023.3.1
scipy==1.10.1

# for generating figures:
matplotlib==3.7.1
seaborn==0.11.2
statannotations==0.6.0

lifelines==0.27.7 # used for concordance index calc
#biopython # used for cmap

# model building
torch==2.0.1
torch-geometric==2.3.1
transformers==4.31.0 # huggingface needed for esm

# optional:
torchsummary==1.5.1
tabulate==0.9.0 # for torch_geometric.nn.summary
ipykernel==6.23.1
plotly==5.14.1
requests==2.31.0
#ray[tune]

submitit==1.4.5
ProDy==2.4.1

# for chemgpt
selfies==1.0.4
28 changes: 15 additions & 13 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
# %%
from src.data_analysis.figures import prepare_df, fig3_edge_feat
from src.utils import config
#%%
from src.data_analysis.stratify_protein import check_davis_names, kinbase_to_df

from transformers import AutoTokenizer, AutoModel
df = kinbase_to_df()
# %%
import json
prot_dict = json.load(open('/home/jyaacoub/projects/data/davis/proteins.txt', 'r'))
# %%
# returns a dictionary of davis protein names (keys) and a truple of the protein name, main family, and subgroup (values)
prots = check_davis_names(prot_dict, df)

# %% plot histogram of main families and their counts
import seaborn as sns
import pandas as pd

df = prepare_df('results/model_media/model_stats.csv')
main_families = [v[1] for v in prots.values()]
main_families = pd.Series(main_families)
sns.histplot(main_families)

# %%
fig3_edge_feat(df, show=True, exclude=[])

# %%
print('test')

#### ChemGPT ####

tokenizer = AutoTokenizer.from_pretrained("ncfrey/ChemGPT-4.7M")
model = AutoModel.from_pretrained("ncfrey/ChemGPT-4.7M")
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
{
"train_loss": [
1.9954961412380228,
1.1123259179788616,
1.0628154035987316,
1.0195641736376337,
0.9921986040354303,
0.9692184450681758,
0.9488197967360636,
0.940734565818468,
0.9137962934833053,
0.9058670597727847,
0.8862639353074858,
0.8738508635093417,
0.8678539732256821,
0.8494533173146501,
0.8396369166078821,
0.8195936667477399,
0.8189072838399263,
0.7990547935354763,
0.7925020887030173,
0.7814916852136365,
0.7716743027445225,
0.7743883602716753,
0.7596010671087862,
0.7445095050750317,
0.7552647592120729,
0.7467957078685276,
0.7425954988209811,
0.7552545020373258,
0.7356984313346643,
0.7358773813080207,
0.7404806569746111,
0.742597665621248,
0.7288618015735286,
0.7157279677128634,
0.7089876080574715,
0.7249864521807274,
0.7202833260994701,
0.7165739839995461,
0.7198128473854303,
0.7385284647065322,
0.7309793036150853,
0.7196460504958455,
0.7323818235880638,
0.7173825457807531,
0.7192918718256781,
0.722492459147707,
0.7188923951584549,
0.7100095051563288,
0.7068662808268471,
0.7108152348680042,
0.714867272776906,
0.7058873030200469,
0.7120912533698488,
0.7118631477294638,
0.6987585334469919,
0.7243739486037366,
0.7150540050959825,
0.7067155465895755,
0.7103592198325601,
0.7352949746562976,
0.7051169296924388,
0.7122998241236252,
0.7036621218497774,
0.7038248123732948,
0.7041913827106488,
0.7179739314700122,
0.7103856882179337,
0.7186218756951589,
0.7088906502565452,
0.7301404662347869,
0.7173615305198241,
0.7239452102839683,
0.7309273347851977,
0.7180232846268774,
0.7328828894079918,
0.727330649090288,
0.723351289419276,
0.7066450789405972,
0.7188684371242349,
0.7147803407459132,
0.7227345445086207,
0.7226420441658887,
0.7060261584933748,
0.7267584679977952,
0.7306896013628065,
0.7264216862767275,
0.7150264033652117,
0.7042778410116391,
0.7083051127290964,
0.7082227646595741,
0.712323370999533,
0.6989217922833245,
0.7070858577694908,
0.6963187569620467,
0.6941914890390054,
0.6874943163526137,
0.6820477824410902,
0.6814433606423899,
0.6704526256473191,
0.668551255891149,
0.6652297119453418,
0.661144870951508,
0.6565740831624881,
0.6341712325152042,
0.6606567606280467,
0.643499315848195,
0.6269802997976673
],
"val_loss": [
1.1018557784788425,
1.2535118042961158,
1.2187856405981676,
1.2760470970204116,
1.41563035914081,
1.3298013821858254,
1.2990144251502456,
1.2875569346781321,
1.3214407840344964,
1.2794410279208683,
1.319744275546987,
1.3295082387811117,
1.2344647842681908,
1.2835332232072076,
1.0444040316799903,
1.0201239069589967,
0.8929435918643346,
0.7577486736679334,
0.7760434191739929,
0.7643616874784692,
0.7721006384292469,
0.7641122302501998,
0.7538701039368072,
0.7168129901525325,
0.7462786953706723,
0.7326808813011554,
0.7197824846747108,
0.7124754123768844,
0.6933985606487016,
0.6861601175938165,
0.6809609269401866,
0.6760130381557794,
0.6701235365950857,
0.6707020423568752,
0.6693344269013775,
0.6909623357809388,
0.682249966654432,
0.7903768657953931,
0.7202901590730896,
0.7126389422280862,
0.7491712721756998,
0.7206178106454126,
0.6651267873543619,
0.6644183549333561,
0.7078817178355118,
0.6839179164416824,
0.7086175708553498,
0.7268013580026418,
0.7167100477489783,
0.8461501004681222,
0.7416352115425207,
0.8179783209400103,
0.8254071828286963,
0.8371337609761188,
0.9208616896025605,
0.97088622076816,
0.8756023296731961,
1.069556474998857,
0.9624758456108383,
0.859606430698928,
1.0121137776074156,
1.0009158287358066,
1.0606968613440353,
1.2167002147937243,
1.5448616045502435,
1.4364397041036014,
1.4733317930053558,
1.7337123687354337,
1.7061408191262044,
1.7082935327713469,
1.3157687119529706,
1.5440313264941878,
1.3741871246334527,
1.2794151946750831,
1.316956363407178,
1.0836214681243033,
1.4886349551956606,
1.446724139446541,
1.2230758408601685,
1.2782749346832423,
0.9854454206203861,
1.434194633492013,
0.9626338091592379,
1.1336133684715735,
1.1052076671369357,
1.161011285033176,
1.1970799049420233,
1.0871652196598265,
1.3620113483547702,
1.1277231313020353,
1.308906744110637,
1.5659388677027506,
0.8891798744897103,
1.1205095070768643,
1.0892479312224972,
0.7972673116668036,
0.8973261003230031,
1.087875069614574,
1.1212975865093975,
1.0956825447326477,
0.9849263438656184,
0.8330175951020216,
0.9281886269841179,
1.0319483467708424,
0.9451814371622703,
1.0264084387830237,
1.1390326986794022
],
"best_epoch": 44
}
Loading

0 comments on commit e6077d0

Please sign in to comment.