diff --git a/playground.py b/playground.py index 589a502f..84383ab0 100644 --- a/playground.py +++ b/playground.py @@ -1,11 +1,11 @@ -#%% -import matplotlib.pyplot as plt -import seaborn as sns -from src.data_analysis.figures import fig6_protein_appearance - +# %% +from matplotlib import pyplot as plt +from src.data_analysis.figures import fig0_dataPro_overlap -sns.set(style="darkgrid") -fig6_protein_appearance() -plt.savefig('results/figures/fig6_protein_appearance_dataset.png', dpi=300) +#%% +for data in ['kiba', 'davis']: + fig0_dataPro_overlap(data=data) + plt.savefig(f'results/figures/fig0_pro_overlap_{data}.png', dpi=300, bbox_inches='tight') + plt.clf() # %% diff --git a/results/figures/fig0_pro_overlap_davis.png b/results/figures/fig0_pro_overlap_davis.png new file mode 100644 index 00000000..c9124bda Binary files /dev/null and b/results/figures/fig0_pro_overlap_davis.png differ diff --git a/results/figures/fig0_pro_overlap_kiba.png b/results/figures/fig0_pro_overlap_kiba.png new file mode 100644 index 00000000..5309100e Binary files /dev/null and b/results/figures/fig0_pro_overlap_kiba.png differ diff --git a/results/figures/fig2_pro_feat_cindex.png b/results/figures/fig2_pro_feat_cindex.png index 7a1219de..8bb81b24 100644 Binary files a/results/figures/fig2_pro_feat_cindex.png and b/results/figures/fig2_pro_feat_cindex.png differ diff --git a/results/figures/fig2_pro_feat_pearson.png b/results/figures/fig2_pro_feat_pearson.png index 3aa3ccb5..b1cb9d66 100644 Binary files a/results/figures/fig2_pro_feat_pearson.png and b/results/figures/fig2_pro_feat_pearson.png differ diff --git a/results/figures/fig3_edge_feat_cindex.png b/results/figures/fig3_edge_feat_cindex.png index 6e4ef24f..d67eaf95 100644 Binary files a/results/figures/fig3_edge_feat_cindex.png and b/results/figures/fig3_edge_feat_cindex.png differ diff --git a/results/figures/fig3_edge_feat_pearson.png b/results/figures/fig3_edge_feat_pearson.png index 966367f8..5306458a 100644 Binary files a/results/figures/fig3_edge_feat_pearson.png and b/results/figures/fig3_edge_feat_pearson.png differ diff --git a/results/figures/fig_combined_edgeViolin_CI-MSE-Pearson.png b/results/figures/fig_combined_edgeViolin_CI-MSE-Pearson.png new file mode 100644 index 00000000..e71ae1d3 Binary files /dev/null and b/results/figures/fig_combined_edgeViolin_CI-MSE-Pearson.png differ diff --git a/results/figures/fig_combined_edgeViolin_CI-MSE.png b/results/figures/fig_combined_edgeViolin_CI-MSE.png new file mode 100644 index 00000000..e2a6fbeb Binary files /dev/null and b/results/figures/fig_combined_edgeViolin_CI-MSE.png differ diff --git a/results/figures/fig_combined_proViolin_CI-MSE-Pearson.png b/results/figures/fig_combined_proViolin_CI-MSE-Pearson.png new file mode 100644 index 00000000..db85cf07 Binary files /dev/null and b/results/figures/fig_combined_proViolin_CI-MSE-Pearson.png differ diff --git a/results/figures/fig_combined_proViolin_CI-MSE.png b/results/figures/fig_combined_proViolin_CI-MSE.png new file mode 100644 index 00000000..e5e3e1c2 Binary files /dev/null and b/results/figures/fig_combined_proViolin_CI-MSE.png differ diff --git a/src/data_analysis/figures.py b/src/data_analysis/figures.py index eaed2b8d..63bd1fcf 100644 --- a/src/data_analysis/figures.py +++ b/src/data_analysis/figures.py @@ -1,13 +1,68 @@ -import os -import matplotlib as mpl -import matplotlib.pyplot as plt +from collections import Counter +import os, pickle, json + import pandas as pd import numpy as np + +import matplotlib as mpl +import matplotlib.pyplot as plt import seaborn as sns from statannotations.Annotator import Annotator + from src.utils import config as cfg from src.utils.loader import Loader +def fig0_dataPro_overlap(data:str='davis', data_root:str=cfg.DATA_ROOT, verbose=False): + data_path = f'{data_root}/{data}' + + Y = pickle.load(open(f'{data_path}/Y', "rb"), encoding='latin1') + row_i, col_i = np.where(np.isnan(Y)==False) + test_fold = json.load(open(f"{data_path}/folds/test_fold_setting1.txt")) + train_fold = json.load(open(f"{data_path}/folds/train_fold_setting1.txt")) + + # loading up train and test protein indices + train_flat = [i for fold in train_fold for i in fold] + test_protein_indices = col_i[test_fold] + train_protein_indices = col_i[train_flat] + + # Overlap in train and test... + overlap = set(train_protein_indices).intersection(set(test_protein_indices)) + if verbose: + print(f'number of unique proteins in train: {len(set(train_protein_indices))}') + print(f'number of unique proteins in test: {len(set(test_protein_indices))}') + print(f'total number of unique proteins: {max(col_i)+1}') + print(f'Intersection of train and test protein indices: {len(overlap)}') + + # counts of overlaping proteins + test_counts = Counter(test_protein_indices) + train_counts = Counter(train_protein_indices) + + overlap_test_counts = {k: test_counts[k] for k in overlap} + overlap_train_counts = {k: train_counts[k] for k in overlap} + + # normalized for set size + norm_overlap_test_counts = {k: v/len(test_protein_indices) for k,v in overlap_test_counts.items()} + norm_overlap_train_counts = {k: v/len(train_protein_indices) for k,v in overlap_train_counts.items()} + + # plot overlap counts + plt.figure(figsize=(15,10)) + plt.subplot(2,1,1) + plt.bar(overlap_train_counts.keys(), overlap_train_counts.values(), label='train', width=1.0) + plt.bar(overlap_test_counts.keys(), overlap_test_counts.values(), label='test', width=1.0) + plt.xlabel('protein index') + plt.ylabel('count') + plt.title(f'Counts of proteins in train and test ({data})') + plt.legend() + + plt.subplot(2,1,2) + plt.bar(norm_overlap_train_counts.keys(), norm_overlap_train_counts.values(), label='train', width=1.0) + plt.bar(norm_overlap_test_counts.keys(), norm_overlap_test_counts.values(), label='test', width=1.0) + plt.xlabel('protein index') + plt.ylabel('Normalized Counts') + plt.title(f'Normalized counts by dataset size of proteins in train and test ({data})') + plt.legend() + plt.tight_layout() + # Figure 1 - Protein overlap cindex difference (nomsa) def fig1_pro_overlap(df, sel_col='cindex', verbose=False, show=False, context='paper'): filtered_df = df[(df['feat'] == 'nomsa') @@ -44,6 +99,8 @@ def fig1_pro_overlap(df, sel_col='cindex', verbose=False, show=False, context='p # Show the plot if show: plt.show() + + return plot_df # Figure 2 - node feature cindex difference # Features -> nomsa, msa, shannon, and esm @@ -186,7 +243,7 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=['af2-anm'], sho # Figure 4: violin plot with error bars for Cross-validation results to show significance among pro feats def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[], - show=False, add_labels=True, add_stats=True): + show=False, add_labels=True, add_stats=True, ax=None): # Extract relevant data filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())] @@ -207,7 +264,7 @@ def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex # Get values for each node feature plot_data = [nomsa, msa, shannon, esm] - ax = sns.violinplot(data=plot_data) + ax = sns.violinplot(data=plot_data, ax=ax) ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm']) ax.set_ylabel(sel_col) ax.set_xlabel('Features') @@ -230,7 +287,7 @@ def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex # Figure 5: violin plot with error bars for Cross-validation results to show significance among edge feats def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[], - show=False, add_labels=True, add_stats=True): + show=False, add_labels=True, add_stats=True, ax=None): filtered_df = df[(df['feat'] == 'nomsa') & (~df['overlap']) & (df['lig_feat'].isna())] filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')] @@ -244,7 +301,7 @@ def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cinde # plot violin plot with annotations plot_data = [binary, simple, anm, af2] - ax = sns.violinplot(data=plot_data) + ax = sns.violinplot(data=plot_data, ax=ax) ax.set_xticklabels(['binary', 'simple', 'anm', 'af2']) ax.set_ylabel(sel_col) ax.set_xlabel('Edge type') @@ -285,6 +342,42 @@ def fig6_protein_appearance(datasets=['kiba', 'PDBbind'], show=False): plt.tight_layout() if show: plt.show() + + +def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'mse'], + fig_callable=fig4_pro_feat_violin, + show=False, **kwargs): + # Create subplots with datasets as columns and cols as rows + fig, axes = plt.subplots(len(metrics), len(datasets), + figsize=(5*len(datasets), 4*len(metrics))) + for i, dataset in enumerate(datasets): + for j, metric in enumerate(metrics): + # Set current subplot + ax = axes[j, i] + + fig_callable(df, sel_col=metric, sel_dataset=dataset, show=False, + ax=ax, **kwargs) + + # Add titles only to the top row and left column + if j == 0: + ax.set_title(f'{dataset}') + ax.set_xlabel('') + ax.set_xticklabels([]) + elif j < len(metrics)-1: # middle row + ax.set_xlabel('') + ax.set_xticklabels([]) + ax.set_title('') + else: # bottom row + ax.set_title('') + + if i == 0: + ax.set_ylabel(metric) + else: + ax.set_ylabel('') + + plt.tight_layout() # Adjust layout to prevent clipping of titles + if show: plt.show() + return fig, axes def prepare_df(csv_p:str=cfg.MODEL_STATS_CSV, old_csv_p:str=None) -> pd.DataFrame: """ @@ -380,4 +473,9 @@ def prepare_df(csv_p:str=cfg.MODEL_STATS_CSV, old_csv_p:str=None) -> pd.DataFram plt.savefig(f"results/figures/fig5_edge_feat_violin_{dataset}_{col}.png", dpi=300, bbox_inches='tight') plt.clf() - # %% + #%% dataset comparisons + plot_df = fig1_pro_overlap(df, sel_col='mse', verbose=verbose, show=False) + + # performance drop values: + grp = plot_df.groupby(['data', 'overlap']).mean() + grp[grp.index.get_level_values(1)].values - grp[~grp.index.get_level_values(1)].values