Skip to content

Commit

Permalink
fix(figures): resolves #61
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Nov 22, 2023
1 parent 9ac53be commit b46897f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 23 deletions.
38 changes: 32 additions & 6 deletions playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,41 @@
from src.data_analysis.figures import fig2_pro_feat, fig3_edge_feat, prepare_df, fig4_pro_feat_violin, fig5_edge_feat_violin

# Load the model statistics table that every figure helper below consumes.
df = prepare_df('results/model_media/model_stats.csv')
data = 'kiba'  # dataset used by the per-dataset violin cells further down
verbose = False

# Keep the returned per-group values so they can be inspected/re-plotted below:
# fig4 returns (nomsa, msa, shannon, esm); fig5 returns (binary, simple, anm, af2).
pdb_f = fig4_pro_feat_violin(df, sel_col='mse', sel_dataset='PDBbind', verbose=verbose)
kiba_e = fig5_edge_feat_violin(df, sel_col='mse', sel_dataset='kiba', verbose=verbose)

# %% simplified plot for debugging
from seaborn import violinplot
from matplotlib import pyplot as plt
import pandas as pd
from statannotations.Annotator import Annotator

new_df = pd.DataFrame({'binary': kiba_e[0], 'simple': kiba_e[1], 'anm': kiba_e[2], 'af2': kiba_e[3]})
new_df

#%%
ax = violinplot(data=kiba_e)
# Labels follow the order of the tuple returned by fig5_edge_feat_violin.
ax.set_xticklabels(['binary', 'simple', 'anm', 'af2'])

# Pairs are positions into `kiba_e`, the same sequence handed to violinplot.
pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
annotator = Annotator(ax, pairs, data=kiba_e, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
                    hide_non_significant=not verbose)
annotator.apply_and_annotate()



# %%
fig4_pro_feat_violin(df, sel_col='cindex', sel_dataset=data, verbose=verbose)
fig4_pro_feat_violin(df, sel_col='mse', sel_dataset=data, verbose=verbose)

#%%
fig5_edge_feat_violin(df, sel_col='cindex', sel_dataset=data, verbose=verbose)
fig5_edge_feat_violin(df, sel_col='mse', sel_dataset=data, verbose=verbose)

# %%
fig2_pro_feat(df, sel_col='pearson')
fig3_edge_feat(df, exclude=['af2-anm'], sel_col='pearson')
51 changes: 34 additions & 17 deletions src/data_analysis/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def fig2_pro_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, ad

# Add statistical annotations
pairs=[("PDBbind", "kiba"), ("PDBbind", "davis"), ("davis", "kiba")]
annotator = Annotator(ax, pairs, data=df, x='data', y='cindex', order=['PDBbind', 'davis', 'kiba'])
annotator = Annotator(ax, pairs, data=df, x='data', y='cindex', order=['PDBbind', 'davis', 'kiba'],
verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True,
line_height=0.005, verbose=verbose)
annotator.apply_and_annotate()
Expand All @@ -155,6 +156,8 @@ def fig2_pro_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, ad

# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)

return nomsa, msa, shannon, esm

# Figure 3 - Edge type cindex difference
# Edges -> binary, simple, anm, af2
Expand Down Expand Up @@ -236,49 +239,63 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, a

# Add statistical annotations
pairs=[("PDBbind", "kiba"), ("PDBbind", "davis"), ("davis", "kiba")]
annotator = Annotator(ax, pairs, data=df, x='data', y='cindex', order=['PDBbind', 'davis', 'kiba'])
annotator = Annotator(ax, pairs, data=df, x='data', y='cindex', order=['PDBbind', 'davis', 'kiba'],
verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True,
line_height=0.005, verbose=False)
line_height=0.005, verbose=verbose)
annotator.apply_and_annotate()
# Show the plot
if show:
plt.show()

# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams.update(mpl.rcParamsDefault)

return binary, simple, anm, af2, af2_anm

# Figure 4: violin plot with error bars for Cross-validation results to show significance among pro feats
def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
                         show=True, add_labels=True, add_stats=True):
    """Draw a violin plot of `sel_col` per protein feature for one dataset.

    Rows are restricted to binary-edge, non-overlap models with no ligand
    feature (`lig_feat` is NaN), then to `sel_dataset` rows that have a
    non-empty `fold`. The remaining rows are split by `feat` into
    nomsa / msa / shannon / ESM groups.

    Args:
        df: table with at least the columns `edge`, `overlap`, `lig_feat`,
            `data`, `fold`, `feat` and `sel_col`.
        sel_dataset: value of the `data` column to plot.
        verbose: print group sizes, pass verbosity on to the annotator, and
            also show non-significant annotations.
        sel_col: metric column to plot on the y axis.
        exclude: accepted for signature parity with the other figure helpers;
            not used by this function.
        show: call `plt.show()` after drawing.
        add_labels: not used by this function (labels are always set).
        add_stats: add Mann-Whitney significance annotations between groups.

    Returns:
        Tuple `(nomsa, msa, shannon, esm)` with the `sel_col` values of each
        feature group, in plotting order.
    """
    # Extract relevant data
    filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())]

    # show all with fold info
    filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')]
    nomsa = filtered_df[(filtered_df['feat'] == 'nomsa')][sel_col]
    msa = filtered_df[(filtered_df['feat'] == 'msa')][sel_col]
    shannon = filtered_df[(filtered_df['feat'] == 'shannon')][sel_col]
    esm = filtered_df[(filtered_df['feat'] == 'ESM')][sel_col]

    # printing length of each feature
    if verbose:
        print(f'nomsa: {len(nomsa)}')
        print(f'msa: {len(msa)}')
        print(f'shannon: {len(shannon)}')
        print(f'esm: {len(esm)}')

    # Get values for each node feature
    # The same list is given to both the plot and the annotator so that the
    # positional pairs below refer to exactly the violins that were drawn.
    plot_data = [nomsa, msa, shannon, esm]
    ax = sns.violinplot(data=plot_data)
    ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm'])
    ax.set_ylabel(sel_col)
    ax.set_xlabel('Features')
    ax.set_title(f'Feature {sel_col} for {sel_dataset}')

    # Annotation for stats
    if add_stats:
        # Pairs are positions into plot_data: 0=nomsa, 1=msa, 2=shannon, 3=esm.
        pairs = [(0, 1), (0, 2), (1, 2)]
        if len(esm) > 0:
            pairs += [(0, 3), (1, 3), (2, 3)]  # add esm pairs if esm is not empty
        annotator = Annotator(ax, pairs, data=plot_data, verbose=verbose)
        annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
                            hide_non_significant=not verbose)
        annotator.apply_and_annotate()

    if show:
        plt.show()

    return nomsa, msa, shannon, esm

# Figure 5: violin plot with error bars for Cross-validation results to show significance among edge feats
def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
Expand All @@ -295,22 +312,24 @@ def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cinde
af2 = filtered_df[filtered_df['edge'] == 'af2'][sel_col]

# plot violin plot with annotations
ax = sns.violinplot(data=[binary, simple, anm, af2])
plot_data = [binary, simple, anm, af2]
ax = sns.violinplot(data=plot_data)
ax.set_xticklabels(['binary', 'simple', 'anm', 'af2'])
ax.set_ylabel(sel_col)
ax.set_xlabel('Edge type')
ax.set_title(f'Edge type {sel_col} for {sel_dataset}')

if add_stats:
pairs = [('binary', 'simple'), ('binary', 'anm'), ('binary', 'af2'),
('simple', 'anm'), ('simple', 'af2'), ('anm', 'af2')]
annotator = Annotator(ax, pairs, data=filtered_df, x='edge', y=sel_col, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True)
pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
annotator = Annotator(ax, pairs, data=plot_data, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
hide_non_significant=not verbose)
annotator.apply_and_annotate()

if show:
plt.show()


return binary, simple, anm, af2


def prepare_df(csv_p:str, old_csv_p='results/model_media/old_model_stats.csv') -> pd.DataFrame:
Expand Down Expand Up @@ -343,8 +362,6 @@ def prepare_df(csv_p:str, old_csv_p='results/model_media/old_model_stats.csv') -
return df




if __name__ == '__main__':
df = prepare_df('results/model_media/model_stats.csv')

Expand Down

0 comments on commit b46897f

Please sign in to comment.