Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development #65

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#%%
import matplotlib.pyplot as plt
import seaborn as sns
from src.data_analysis.figures import fig6_protein_appearance

# %%
from matplotlib import pyplot as plt
from src.data_analysis.figures import fig0_dataPro_overlap

sns.set(style="darkgrid")
fig6_protein_appearance()
plt.savefig('results/figures/fig6_protein_appearance_dataset.png', dpi=300)

#%%
for data in ['kiba', 'davis']:
fig0_dataPro_overlap(data=data)
plt.savefig(f'results/figures/fig0_pro_overlap_{data}.png', dpi=300, bbox_inches='tight')
plt.clf()
# %%
Binary file added results/figures/fig0_pro_overlap_davis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/figures/fig0_pro_overlap_kiba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified results/figures/fig2_pro_feat_cindex.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified results/figures/fig2_pro_feat_pearson.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified results/figures/fig3_edge_feat_cindex.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified results/figures/fig3_edge_feat_pearson.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
114 changes: 106 additions & 8 deletions src/data_analysis/figures.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,68 @@
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import Counter
import os, pickle, json

import pandas as pd
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from statannotations.Annotator import Annotator

from src.utils import config as cfg
from src.utils.loader import Loader

def fig0_dataPro_overlap(data:str='davis', data_root:str=cfg.DATA_ROOT, verbose=False):
data_path = f'{data_root}/{data}'

Y = pickle.load(open(f'{data_path}/Y', "rb"), encoding='latin1')
row_i, col_i = np.where(np.isnan(Y)==False)
test_fold = json.load(open(f"{data_path}/folds/test_fold_setting1.txt"))
train_fold = json.load(open(f"{data_path}/folds/train_fold_setting1.txt"))

# loading up train and test protein indices
train_flat = [i for fold in train_fold for i in fold]
test_protein_indices = col_i[test_fold]
train_protein_indices = col_i[train_flat]

# Overlap in train and test...
overlap = set(train_protein_indices).intersection(set(test_protein_indices))
if verbose:
print(f'number of unique proteins in train: {len(set(train_protein_indices))}')
print(f'number of unique proteins in test: {len(set(test_protein_indices))}')
print(f'total number of unique proteins: {max(col_i)+1}')
print(f'Intersection of train and test protein indices: {len(overlap)}')

# counts of overlaping proteins
test_counts = Counter(test_protein_indices)
train_counts = Counter(train_protein_indices)

overlap_test_counts = {k: test_counts[k] for k in overlap}
overlap_train_counts = {k: train_counts[k] for k in overlap}

# normalized for set size
norm_overlap_test_counts = {k: v/len(test_protein_indices) for k,v in overlap_test_counts.items()}
norm_overlap_train_counts = {k: v/len(train_protein_indices) for k,v in overlap_train_counts.items()}

# plot overlap counts
plt.figure(figsize=(15,10))
plt.subplot(2,1,1)
plt.bar(overlap_train_counts.keys(), overlap_train_counts.values(), label='train', width=1.0)
plt.bar(overlap_test_counts.keys(), overlap_test_counts.values(), label='test', width=1.0)
plt.xlabel('protein index')
plt.ylabel('count')
plt.title(f'Counts of proteins in train and test ({data})')
plt.legend()

plt.subplot(2,1,2)
plt.bar(norm_overlap_train_counts.keys(), norm_overlap_train_counts.values(), label='train', width=1.0)
plt.bar(norm_overlap_test_counts.keys(), norm_overlap_test_counts.values(), label='test', width=1.0)
plt.xlabel('protein index')
plt.ylabel('Normalized Counts')
plt.title(f'Normalized counts by dataset size of proteins in train and test ({data})')
plt.legend()
plt.tight_layout()

# Figure 1 - Protein overlap cindex difference (nomsa)
def fig1_pro_overlap(df, sel_col='cindex', verbose=False, show=False, context='paper'):
filtered_df = df[(df['feat'] == 'nomsa')
Expand Down Expand Up @@ -44,6 +99,8 @@ def fig1_pro_overlap(df, sel_col='cindex', verbose=False, show=False, context='p

# Show the plot
if show: plt.show()

return plot_df

# Figure 2 - node feature cindex difference
# Features -> nomsa, msa, shannon, and esm
Expand Down Expand Up @@ -186,7 +243,7 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=['af2-anm'], sho

# Figure 4: violin plot with error bars for Cross-validation results to show significance among pro feats
def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
show=False, add_labels=True, add_stats=True):
show=False, add_labels=True, add_stats=True, ax=None):
# Extract relevant data
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())]

Expand All @@ -207,7 +264,7 @@ def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex

# Get values for each node feature
plot_data = [nomsa, msa, shannon, esm]
ax = sns.violinplot(data=plot_data)
ax = sns.violinplot(data=plot_data, ax=ax)
ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm'])
ax.set_ylabel(sel_col)
ax.set_xlabel('Features')
Expand All @@ -230,7 +287,7 @@ def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex

# Figure 5: violin plot with error bars for Cross-validation results to show significance among edge feats
def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
show=False, add_labels=True, add_stats=True):
show=False, add_labels=True, add_stats=True, ax=None):
filtered_df = df[(df['feat'] == 'nomsa') & (~df['overlap']) & (df['lig_feat'].isna())]
filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')]

Expand All @@ -244,7 +301,7 @@ def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cinde

# plot violin plot with annotations
plot_data = [binary, simple, anm, af2]
ax = sns.violinplot(data=plot_data)
ax = sns.violinplot(data=plot_data, ax=ax)
ax.set_xticklabels(['binary', 'simple', 'anm', 'af2'])
ax.set_ylabel(sel_col)
ax.set_xlabel('Edge type')
Expand Down Expand Up @@ -285,6 +342,42 @@ def fig6_protein_appearance(datasets=['kiba', 'PDBbind'], show=False):
plt.tight_layout()

if show: plt.show()


def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'mse'],
fig_callable=fig4_pro_feat_violin,
show=False, **kwargs):
# Create subplots with datasets as columns and cols as rows
fig, axes = plt.subplots(len(metrics), len(datasets),
figsize=(5*len(datasets), 4*len(metrics)))
for i, dataset in enumerate(datasets):
for j, metric in enumerate(metrics):
# Set current subplot
ax = axes[j, i]

fig_callable(df, sel_col=metric, sel_dataset=dataset, show=False,
ax=ax, **kwargs)

# Add titles only to the top row and left column
if j == 0:
ax.set_title(f'{dataset}')
ax.set_xlabel('')
ax.set_xticklabels([])
elif j < len(metrics)-1: # middle row
ax.set_xlabel('')
ax.set_xticklabels([])
ax.set_title('')
else: # bottom row
ax.set_title('')

if i == 0:
ax.set_ylabel(metric)
else:
ax.set_ylabel('')

plt.tight_layout() # Adjust layout to prevent clipping of titles
if show: plt.show()
return fig, axes

def prepare_df(csv_p:str=cfg.MODEL_STATS_CSV, old_csv_p:str=None) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -380,4 +473,9 @@ def prepare_df(csv_p:str=cfg.MODEL_STATS_CSV, old_csv_p:str=None) -> pd.DataFram
plt.savefig(f"results/figures/fig5_edge_feat_violin_{dataset}_{col}.png", dpi=300, bbox_inches='tight')
plt.clf()

# %%
#%% dataset comparisons
plot_df = fig1_pro_overlap(df, sel_col='mse', verbose=verbose, show=False)

# performance drop values:
grp = plot_df.groupby(['data', 'overlap']).mean()
grp[grp.index.get_level_values(1)].values - grp[~grp.index.get_level_values(1)].values