Skip to content

Commit

Permalink
fix(prepare_df): parse for GVPL_ESM model results #90
Browse files (browse the repository at this point in the history)
  • Loading branch information
jyaacoub committed Jul 3, 2024
1 parent 3b8b0a8 commit f442919
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
6 changes: 4 additions & 2 deletions playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
'aflow': ('nomsa', 'aflow', 'original', 'binary'),
# 'gvpP': ('gvp', 'binary', 'original', 'binary'),
# 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
- 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
+ # 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
# 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
# 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
+ #GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
+ 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

df = prepare_df()
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
- fig_scale=(8,5))
+ fig_scale=(10,5), add_stats=True)
plt.xticks(rotation=45)


Expand Down
1 change: 1 addition & 0 deletions src/analysis/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def prepare_df(csv_p:str=cfg.MODEL_STATS_CSV, old_csv_p:str=None) -> pd.DataFrame:
df['dropout'] = df['run'].str.extract(r'_(\d+\.?\d*)D_', expand=False).astype(float)

# ESM models
df.loc[df['run'].str.contains('ESM') & df['run'].str.contains('nomsaF'), 'feat'] = 'ESM'
df.loc[df['run'].str.contains('EDM') & df['run'].str.contains('nomsaF'), 'feat'] = 'ESM'
df.loc[df['run'].str.contains('EDAM'), 'feat'] += '-ESM'
df.loc[df['run'].str.contains('EDIM') & df['run'].str.contains('nomsaF'), 'feat'] = 'ESM'
Expand Down
2 changes: 1 addition & 1 deletion src/data_prep/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def save_subset(self, idxs:Iterable[int]|data.Sampler|data.DataLoader,
path = os.path.join(self.root, subset_name)
os.makedirs(path, exist_ok=True)
sub_df.to_csv(os.path.join(path, self.processed_file_names[0])) # redundant save since it is not used and mainly just for tracking prots.
- sub_df.to_csv(os.path.join(path, self.processed_file_names[3]))
+ sub_df.to_csv(os.path.join(path, self.processed_file_names[3])) # clean_XY.csv
torch.save(sub_prots, os.path.join(path, self.processed_file_names[1]))
torch.save(sub_lig, os.path.join(path, self.processed_file_names[2]))
return path
Expand Down

0 comments on commit f442919

Please sign in to comment.