Skip to content

Commit

Permalink
results(prot_counts): Davis and kiba clusters #57
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Dec 13, 2023
1 parent 93e1c84 commit 039beb3
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 21 deletions.
126 changes: 105 additions & 21 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,84 @@
#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_style('darkgrid')
dataset = 'kiba'

# cols are code,SMILE,prot_seq,pkd,prot_id
df_full = pd.read_csv(f'../data/DavisKibaDataset/{dataset}/nomsa_binary_original_binary/full/XY.csv',
index_col='code')
prot_counts = df_full.index.value_counts()

# %% get subgroup mse
subset = 'train' #'test'

for model_type in ['EDI', 'DG']:
if model_type == 'EDI':
model_path = lambda x: f'results/model_media/{subset}_set_pred/EDIM_{dataset}{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
elif model_type == 'DG':
if dataset == 'davis':
model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_{dataset}{x}D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
elif dataset == 'kiba':
model_path = lambda x: f'results/model_media/{subset}_set_pred/DGM_{dataset}{x}D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_{subset}Pred.csv'

data_clust = {} # {cluster: [mse1, mse2, ...], ...}
for fold in range(5):
pred = pd.read_csv(model_path(fold), index_col='name')

for k in prot_counts.keys():
matched = pred[pred.index == k]
if matched.empty: # will be empty if not in this fold
continue
mse = ((matched.pred - matched.actual)**2).mean()

# add main mse to dict
data_clust[k] = data_clust.get(k, []) + [mse]

# merge counts with mse
x = []
y = []
z = []
for k in data_clust.keys():
x += [prot_counts[k]] * len(data_clust[k])
y += data_clust[k]
z += [k] * len(data_clust[k])

# scatter plot with x axis as count and y axis as mse
plt.figure(figsize=(10, 5))
ax = sns.scatterplot(x=x, y=y, hue=z)
# line of best fit
m, b = np.polyfit(x, y, 1)
lines = plt.plot(x, m*np.array(x) + b, color='black', linestyle='dotted',
label=f'y={m*10000:.2f}e-4x+{b:.2f}', linewidth=2)

# correlation
corr = np.corrcoef(x, y)[0, 1]
plt.xlabel('Protein Count in Entire Dataset')
plt.ylabel('MSE')
plt.ylim(0, 1.6)
plt.legend(handles=[lines[0]], loc='upper left', title=f'Correlation: {corr:.3f}')
plt.title(f'Protein Count vs {subset} MSE ({model_type} Model)')
plt.savefig(f"{model_type}_kiba.png")
plt.clf()


exit()













# %%
import pprint
from src.utils.mmseq2 import MMseq2Runner
Expand Down Expand Up @@ -88,7 +169,8 @@ def get_cluster_details(clust):


subset = 'test' #'test'
for model_type in ['EDI']:
verbose = False
for model_type in ['EDI', 'DG']:
if model_type == 'EDI':
model_path = lambda x: f'results/model_media/{subset}_set_pred/EDIM_davis{x}D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_{subset}Pred.csv'
elif model_type == 'DG':
Expand Down Expand Up @@ -129,31 +211,33 @@ def get_cluster_details(clust):
corr = np.corrcoef(x, y)[0, 1]
plt.xlabel('Number of Proteins in cluster')
plt.ylabel('MSE')
plt.ylim([0.0,1.6])
plt.legend(handles=[lines[0]], loc='upper left', title=f'Correlation: {corr:.3f}')
plt.title(f'Subgroup Size vs {subset} MSE ({model_type} Model)')
plt.show()
plt.savefig(f'{model_type}.png')
plt.clf()


# print worse performers
print('WORST PERFORMERS')
for k in data_clust.keys():
if np.mean(data_clust[k]) > 0.5:
clust = list(clusters[k])
print(f'\n\n### Cluster {k} has {len(clust)} proteins and mean mse of {np.mean(data_clust[k]):.3f} '
f'with std {np.std(data_clust[k]):.3f}')
get_cluster_details(clust)

# print best performers
print('')
print('#'*50)
print('BEST PERFORMERS')
for k in data_clust.keys():
if np.mean(data_clust[k]) < 0.15:
clust = list(clusters[k])
print(f'\n\n### Cluster {k} has {len(clust)} proteins and mean mse of {np.mean(data_clust[k]):.3f} '
f'with std {np.std(data_clust[k]):.3f}')
get_cluster_details(clust)
if verbose:
# print worse performers
print('WORST PERFORMERS')
for k in data_clust.keys():
if np.mean(data_clust[k]) > 0.5:
clust = list(clusters[k])
print(f'\n\n### Cluster {k} has {len(clust)} proteins and mean mse of {np.mean(data_clust[k]):.3f} '
f'with std {np.std(data_clust[k]):.3f}')
get_cluster_details(clust)

# print best performers
print('')
print('#'*50)
print('BEST PERFORMERS')
for k in data_clust.keys():
if np.mean(data_clust[k]) < 0.15:
clust = list(clusters[k])
print(f'\n\n### Cluster {k} has {len(clust)} proteins and mean mse of {np.mean(data_clust[k]):.3f} '
f'with std {np.std(data_clust[k]):.3f}')
get_cluster_details(clust)

# %% Cluster 43 is an outlier with the worse performance:

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 039beb3

Please sign in to comment.