Implement RayTune for DDP #68
Getting session info

For session information we can use `ray.air.session`. This is the same as the internal `ray.train._internal.session`; both essentially just rely on the same underlying session object. Code showing that `ray.air.session` is the same is here: https://github.com/ray-project/ray/blob/master/python/ray/train/_internal/session.py#L15
Results comparison for tuned DGM

This model has already been well tuned, so expecting it to perform much better is unrealistic, but we do see some improvements:

#%% Comparing tuned and non-tuned DG models
from src.data_analysis.figures import prepare_df
df = prepare_df('./results/model_media/model_stats.csv')
df_filtered = df[(df['feat'] == 'nomsa') & (df['edge'] == 'binary') & (df['data'] == 'davis')
& (~df['overlap']) & df['lig_feat'].isna() &
(df['fold'] != '')]
df_grp = df_filtered.groupby(['lr', 'dropout', 'batch_size'])  # one group per hyperparameter setting
#%% plot the two groups
import matplotlib.pyplot as plt
import seaborn as sns
### MSE
sns.violinplot(x='batch_size', y='mse', data=df_filtered)
plt.title('Tuned vs Non-Tuned DG Model Performance')
# change tick names
plt.xticks([0, 1], ['Non-Tuned', 'Tuned'])
plt.xlabel('Model')
plt.show()
### CINDEX
sns.violinplot(x='batch_size', y='cindex', data=df_filtered)
plt.title('Tuned vs Non-Tuned DG Model Performance')
# change tick names
plt.xticks([0, 1], ['Non-Tuned', 'Tuned'])
plt.xlabel('Model')
plt.show()
#%%
from src.data_analysis.figures import fig2_pro_feat
import matplotlib as mpl
import numpy as np
verbose=False
sel_col='mse'
exclude=['shannon', 'msa']
show=True
add_labels=True
context='poster'
# Extract relevant data
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['data'] == 'davis')
                 & (df['fold'] != '') & (df['lig_feat'].isna())].copy()  # copy so the .loc assignment below is safe
# setting feat of Raytuned optimized model to 'tuned'
filtered_df.loc[(filtered_df['dropout'] == 0.24) &
(filtered_df['batch_size'] == 128), 'feat'] = 'tuned'
#%% get only data, feat, and sel_col columns
plot_df = filtered_df[['data', 'feat', sel_col]]
hue_order = ['nomsa', 'msa', 'shannon', 'ESM', 'tuned']
for f in exclude:
plot_df = plot_df[plot_df['feat'] != f]
if f in hue_order:
hue_order.remove(f)
# Create a bar plot using Seaborn
if context == 'poster':
scale = 1
plt.figure(figsize=(14, 7))
else:
scale = 0.8
plt.figure(figsize=(10, 5))
sns.set(style="darkgrid")
sns.set_context(context)
ax = sns.barplot(data=plot_df, x='feat', y=sel_col, palette='deep', estimator=np.mean,
order=hue_order,
errcolor='gray', errwidth=2)
sns.stripplot(data=plot_df, x='feat', y=sel_col, palette='deep',
order=hue_order,
size=6*scale, jitter=True, dodge=True, alpha=0.7, ax=ax)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[len(hue_order):], labels[len(hue_order):],
title='', loc='upper right', prop={'size': 14*scale})
if add_labels:
for i in ax.containers:
ax.bar_label(i, fmt='%.3f', fontsize=13*scale,
label_type='center')
# Set the title
ax.set_title(f'Node feature performance ({"concordance index" if sel_col == "cindex" else sel_col})',
fontsize=16*scale)
# Set the y-axis label and limit
ax.set_ylabel(sel_col)
if sel_col == 'cindex':
ax.set_ylim([0.5, 1]) # 0.5 is the worst cindex value
if sel_col == 'pearson':
ax.set_ylim([0, 1])
# Show the plot
if show:
plt.show()
# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)
Based on the Ray docs and this discussion question, it seems that the only way to have a distributed environment work with RayTune is to use their Train module.
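A rough pseudocode sketch of that setup, based on the Ray 2.x API (names and values here are illustrative, not this project's actual code): wrap the per-worker training loop in a `TorchTrainer`, then hand the trainer to `tune.Tuner` so Tune searches hyperparameters while Ray Train manages the distributed (DDP) workers.

```python
# Pseudocode sketch (Ray 2.x style); training-loop body elided with `...`.
from ray import tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    ...  # build DDP model, train, and report metrics via the session

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
tuner = tune.Tuner(
    trainer,
    param_space={"train_loop_config": {"lr": tune.loguniform(1e-5, 1e-3)}},
)
results = tuner.fit()
```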
Closing this issue as completed since it works now.
Here are some results from a recent test of tuning DGraphDTA:
Some trials resulted in an ERROR due to SLURM timeouts. The best hyperparameters were:
Will try training with these to see the performance improvement.
The goal is to do the same with the EDI models. This will be tricky but could be well worth it.
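As an aside, timeouts of that kind usually mean the job hit its SLURM walltime; the usual fix is to request more time in the submission script (the value below is illustrative, not this project's actual setting):

```shell
#SBATCH --time=24:00:00   # raise the walltime so long-running tuning trials are not killed
```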