Implement RayTune for DDP #68

Closed
jyaacoub opened this issue Dec 8, 2023 · 4 comments · Fixed by #67

@jyaacoub

jyaacoub commented Dec 8, 2023

Here are some results from a recent test of tuning DGraphDTA:

Trial status: 43 TERMINATED | 7 ERROR
Current time: 2023-12-07 18:16:25. Total running time: 1d 2hr 29min 55s
Logical resource usage: 6.0/26 CPUs, 1.0/3 GPUs (0.0/2.0 accelerator_type:V100, 0.0/1.0 accelerator_type:P100)
Current best trial: 5d1697a7 with mean_loss=0.8118192549794913 and params={'epochs': 15, 'model': 'DG', 'dataset': 'davis', 'feature_opt': 'nomsa', 'edge_opt': 'binary', 'fold_selection': 0, 'save_checkpoint': False, 'lr': 0.00011928463854427786, 'dropout': 0.23544603499316552, 'batch_size': 128}
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name           status                lr     dropout     batch_size       loss     iter     total time (s) │
├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ objective_f8c6f0ce   TERMINATED   0.00710604    0.441609             128   1.02605        15           2892.48  │
│ objective_74930ca2   TERMINATED   0.000105874   0.151808              64   0.890213       15           6666.19  │
│ objective_44424cd6   TERMINATED   0.00804006    0.48004               32   0.861907       15           8359.86  │
│ objective_5d1697a7   TERMINATED   0.000119285   0.235446             128   0.811819       15           7687.3   │
│ objective_9e9ab4a8   TERMINATED   0.00105199    0.203082              32   0.871349       15           6874.14  │
│ objective_f78f37db   TERMINATED   0.000819155   0.452793              16   0.887771       15           3995.45  │
│ objective_2c08490a   TERMINATED   0.000177966   0.191434              16   0.964595       15           6079     │
│ objective_00596c9a   TERMINATED   0.00184594    0.199897              32   0.865287       15           3482.35  │
│ objective_2262ea3d   TERMINATED   0.00136396    0.490018              16   0.994019       15           8457.83  │
│ objective_5aabac85   TERMINATED   0.00336465    0.451498              16   0.910935       15           7214.91  │
│ objective_2390d842   TERMINATED   0.00609369    0.439089              32   0.861563       15           9649.15  │
│ objective_b66fa998   TERMINATED   0.00168243    0.195735             128   0.983036       15           2913.42  │
│ objective_98701f7c   TERMINATED   0.000589411   0.193649              64   0.862975       15           6330.44  │
│ objective_00d8c0f9   TERMINATED   0.000292745   0.0469459            128   1.23125        15           2930.09  │
│ objective_2512233d   TERMINATED   0.000526634   0.115358             128   0.906146       15           7646.42  │
│ objective_e34b528a   TERMINATED   0.000332387   0.0268555            128   1.59321        15           8601.55  │
│ objective_a365ee56   TERMINATED   0.000366016   0.323265              32   0.864072       15           3685.59  │
│ objective_965ab5f3   TERMINATED   0.00866796    0.348361              32   0.861505       15           9705.65  │
│ objective_2d3a31be   TERMINATED   0.000463571   0.338539              64   0.861775       15           6381.91  │
│ objective_3e4ed857   TERMINATED   0.00030737    0.340331              32   0.862969       15           3446.03  │
│ objective_47be0589   TERMINATED   0.0039828     0.321489             128   0.997658       15           7731.87  │
│ objective_71d21903   TERMINATED   0.0037087     0.307039              32   0.874798       15           3229.29  │
│ objective_afd0eaab   TERMINATED   0.0040534     0.288038             128   0.992873       15           2806.05  │
│ objective_028b8f79   TERMINATED   0.000108837   0.396209             128   0.872468       15           9272.22  │
│ objective_475a3843   TERMINATED   0.00984875    0.365005              32   0.862079       15           6526.68  │
│ objective_4bee96c9   TERMINATED   0.00956038    0.368061              32   0.862217       15           3342.05  │
│ objective_e4ad005f   TERMINATED   0.00974355    0.393573              32   0.862119       15           3310.31  │
│ objective_725064a1   TERMINATED   0.00606494    0.397888              32   0.866959       15           9703.14  │
│ objective_24a75521   TERMINATED   0.00535432    0.269183              32   0.868425       15           9652.46  │
│ objective_a0af311c   TERMINATED   0.00225667    0.268799              64   0.913351       15           9519.35  │
│ objective_47593fa9   TERMINATED   0.00248622    0.244709             128   1.00566        15           7567.79  │
│ objective_9b980f42   TERMINATED   0.00248437    0.261983              64   0.912989       15           3784.22  │
│ objective_69c9bbaa   TERMINATED   0.00231254    0.249876              64   0.916671       15           9556.39  │
│ objective_dae66e74   TERMINATED   0.00660085    0.429673             128   0.920999       15           4782.56  │
│ objective_c7545b5b   TERMINATED   0.00641722    0.419257              64   0.890712       15           7629.53  │
│ objective_e7b6b34e   TERMINATED   0.00725481    0.420755              64   0.862382       15           4955.95  │
│ objective_36ccd8eb   TERMINATED   0.00018237    0.343186              64   0.84636        15           9503.14  │
│ objective_9e6babfd   TERMINATED   0.00524943    0.354616              64   0.890387       15           5289.6   │
│ objective_d35233be   TERMINATED   0.000874823   0.347582              64   0.861592       15           7715.35  │
│ objective_6a956028   TERMINATED   0.000188524   0.355901              64   0.946349       15           5587.46  │
│ objective_50048e57   TERMINATED   0.000904451   0.46742               32   0.868344       15           9344.52  │
│ objective_3ce9c421   TERMINATED   0.000186455   0.464254              32   0.851946       15           7718.24  │
│ objective_86747a45   TERMINATED   0.000200258   0.468154              16   1.04599        15           5483.12  │
│ objective_6f6d5d1b   ERROR        0.00321539    0.288334             128   0.862627       11           6327.9   │
│ objective_f4212379   ERROR        0.00374068    0.274769              64   0.872675       12           5223.13  │
│ objective_f33aa378   ERROR        0.00903399    0.398166              32   1.72131         7           1505.61  │
│ objective_172c8eb2   ERROR        0.00859029    0.384062              32                                        │
│ objective_42f88972   ERROR        0.0095653     0.388995              32   0.864043       14           7969.27  │
│ objective_57e7f0f2   ERROR        0.00589413    0.25981               32   2.26465         2            885.104 │
│ objective_70608908   ERROR        0.00577274    0.287196              32   0.889743        7           1495.73  │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Some trials ended in ERROR because of SLURM timeouts. The best hyperparameters were:

'lr': 0.00012, 'dropout': 0.24, 'batch_size': 128

Will try training with these to see the performance improvement.
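
The ERROR rows stopped before epoch 15, so their losses are not directly comparable with the finished trials. A minimal sketch of picking the best trial from a pasted status table while skipping them (the helper names `parse_trials` and `best_trial` are hypothetical, and the parser assumes the eight-column layout shown above):

```python
# A few rows copied from the status table above.
SAMPLE = """\
│ objective_5d1697a7   TERMINATED   0.000119285   0.235446             128   0.811819       15           7687.3   │
│ objective_36ccd8eb   TERMINATED   0.00018237    0.343186              64   0.84636        15           9503.14  │
│ objective_6f6d5d1b   ERROR        0.00321539    0.288334             128   0.862627       11           6327.9   │
│ objective_172c8eb2   ERROR        0.00859029    0.384062              32                                        │
"""


def parse_trials(table_text):
    """Parse complete data rows of the status table into dicts."""
    trials = []
    for line in table_text.splitlines():
        fields = line.strip().strip("│").split()
        # Skip borders, the header, and rows with missing loss/iter/time.
        if len(fields) != 8 or fields[1] not in ("TERMINATED", "ERROR"):
            continue
        name, status, lr, dropout, batch_size, loss, iters, seconds = fields
        trials.append({
            "name": name,
            "status": status,
            "lr": float(lr),
            "dropout": float(dropout),
            "batch_size": int(batch_size),
            "loss": float(loss),
            "iter": int(iters),
            "time_s": float(seconds),
        })
    return trials


def best_trial(trials):
    """Return the lowest-loss trial among those that ran to completion."""
    finished = [t for t in trials if t["status"] == "TERMINATED"]
    return min(finished, key=lambda t: t["loss"])


if __name__ == "__main__":
    best = best_trial(parse_trials(SAMPLE))
    print(best["name"], best["lr"], best["dropout"], best["batch_size"])
```

When the Ray results object is still available, its own best-result lookup is the better route; this is only meant for working from pasted console output.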


The goal is to do the same with the EDI models. This will be tricky but could be well worth it.

@jyaacoub

jyaacoub commented Dec 8, 2023

Getting Session info

For session information we can use ray.air.session.

This is the same as ray.train.context.TrainContext with documentation at https://docs.ray.io/en/latest/train/api/doc/ray.train.context.TrainContext.html.

Both essentially just rely on _TrainSession.


The code showing that ray.air.session is equivalent can be found at https://github.com/ray-project/ray/blob/master/python/ray/train/_internal/session.py#L15
and https://github.com/ray-project/ray/blob/master/python/ray/air/_internal/session.py#L8.
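
A minimal sketch of reading that session info inside a training loop. It assumes a Ray version where `ray.train.get_context()` is available (2.7 or later, as far as I know); `worker_info` is a hypothetical helper, not something from this repo:

```python
def worker_info(ctx):
    """Collect rank and size info from a TrainContext-like object."""
    return {
        "world_size": ctx.get_world_size(),
        "world_rank": ctx.get_world_rank(),
        "local_rank": ctx.get_local_rank(),
    }


def train_loop_per_worker(config):
    # Imported lazily so this file can be loaded without Ray installed.
    import ray.train

    info = worker_info(ray.train.get_context())
    if info["world_rank"] == 0:
        # Only the rank-0 worker logs, to avoid duplicated output under DDP.
        print(f"training on {info['world_size']} workers")
```

Keeping `worker_info` separate from the Ray import makes it easy to check with a stub context outside of a Ray session.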

@jyaacoub

Results comparison for tuned DGM

This model was already well tuned, so expecting a large gain is unrealistic, but we do see some improvement:

[Three images: plots comparing tuned and non-tuned DG model performance, generated by the code below]

Code

#%% Comparing tuned and non tuned DG models
from src.data_analysis.figures import prepare_df

df = prepare_df('./results/model_media/model_stats.csv')
df_filtered = df[(df['feat'] == 'nomsa') & (df['edge'] == 'binary') & (df['data'] == 'davis') 
                & (~df['overlap']) & df['lig_feat'].isna() &
                (df['fold'] != '')]

df_grp = df_filtered.groupby(['lr', 'dropout', 'batch_size'])

#%% plot the two groups
import matplotlib.pyplot as plt
import seaborn as sns

### MSE
sns.violinplot(x='batch_size', y='mse', data=df_filtered)
plt.title('Tuned vs Non-Tuned DG Model Performance')

# change tick names
plt.xticks([0, 1], ['Non-Tuned', 'Tuned'])
plt.xlabel('Model')
plt.show()

### CINDEX
sns.violinplot(x='batch_size', y='cindex', data=df_filtered)
plt.title('Tuned vs Non-Tuned DG Model Performance')

# change tick names
plt.xticks([0, 1], ['Non-Tuned', 'Tuned'])
plt.xlabel('Model')
plt.show()

#%%
from src.data_analysis.figures import fig2_pro_feat
import matplotlib as mpl
import numpy as np

verbose=False
sel_col='mse'
exclude=['shannon', 'msa']
show=True
add_labels=True
context='poster'

# Extract relevant data (copy so the .loc assignment below does not act on a view of df)
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['data'] == 'davis')
                    & (df['fold'] != '') & (df['lig_feat'].isna())].copy()

# setting feat of Raytuned optimized model to 'tuned'
filtered_df.loc[(filtered_df['dropout'] == 0.24) & 
            (filtered_df['batch_size'] == 128), 'feat'] = 'tuned'


#%% get only data, feat, and sel_col columns
plot_df = filtered_df[['data', 'feat', sel_col]]
    
hue_order = ['nomsa', 'msa', 'shannon', 'ESM', 'tuned']
for f in exclude:
    plot_df = plot_df[plot_df['feat'] != f]
    if f in hue_order:
        hue_order.remove(f)   

# Create a bar plot using Seaborn
if context == 'poster':
    scale = 1
    plt.figure(figsize=(14, 7))
else:
    scale = 0.8
    plt.figure(figsize=(10, 5))

sns.set(style="darkgrid")
sns.set_context(context)
ax = sns.barplot(data=plot_df, x='feat', y=sel_col, palette='deep', estimator=np.mean,
                    order=hue_order, 
                    errcolor='gray', errwidth=2)
sns.stripplot(data=plot_df, x='feat', y=sel_col, palette='deep',
                order=hue_order,
                size=6*scale, jitter=True, dodge=True, alpha=0.7, ax=ax)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[len(hue_order):], labels[len(hue_order):], 
            title='', loc='upper right', prop={'size': 14*scale})

if add_labels:
    for i in ax.containers: 
        ax.bar_label(i, fmt='%.3f', fontsize=13*scale, 
                        label_type='center')
        
# Set the title
ax.set_title(f'Node feature performance ({"concordance index" if sel_col == "cindex" else sel_col})',
                fontsize=16*scale)

# Set the y-axis label and limit
ax.set_ylabel(sel_col)
if sel_col == 'cindex':
    ax.set_ylim([0.5, 1])  # 0.5 is the worst cindex value
if sel_col == 'pearson':
    ax.set_ylim([0, 1])

# Show the plot
if show:
    plt.show()
    
# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)

@jyaacoub

Based on the Ray Docs and this discussion question, it seems that the only way to get a distributed environment working with RayTune is to use their Train module.
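
A rough sketch of what that combination looks like: a `TorchTrainer` wrapped in a `Tuner`, with the search space nested under `train_loop_config`. The scaling values match the ones used for this issue's final run; the search ranges, `num_samples`, and the name `build_tuner` are assumptions for illustration, and the training loop is expected to call `ray.train.report({"loss": ...})` each epoch:

```python
# Worker layout: 4 DDP workers, each with 2 CPUs and 1 GPU.
SCALING = {
    "num_workers": 4,
    "use_gpu": True,
    "resources_per_worker": {"CPU": 2, "GPU": 1},
}


def build_tuner(train_loop_per_worker, num_samples=50):
    """Wrap a DDP training loop in a TorchTrainer and hand it to Ray Tune."""
    # Imported lazily so this file can be loaded without Ray installed.
    from ray import tune
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(**SCALING),
    )

    # Hypothetical search ranges; everything under "train_loop_config"
    # is passed to train_loop_per_worker as its `config` dict.
    param_space = {
        "train_loop_config": {
            "epochs": 15,
            "lr": tune.loguniform(1e-4, 1e-2),
            "dropout": tune.uniform(0.0, 0.5),
            "batch_size": tune.choice([4, 8, 12, 16]),
        }
    }

    return tune.Tuner(
        trainer,
        param_space=param_space,
        tune_config=tune.TuneConfig(metric="loss", mode="min", num_samples=num_samples),
    )
```

Calling `build_tuner(train_loop_per_worker).fit()` would then launch the trials, with each trial reserving resources for all of its workers.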

@jyaacoub

Closing this issue as completed since it works now.

Logical resource usage: 9.0/32 CPUs, 4.0/12 GPUs (0.0/3.0 accelerator_type:V100)
Current best trial: dc87c6ae with loss=0.6445849582634336 and params={
'train_loop_config': {
    'epochs': 15, 'model': 'SPD', 'dataset': 'davis', 'feature_opt': 'foldseek', 'edge_opt': 'binary', 
    'fold_selection': 0, 'save_checkpoint': False, 

    'lr': 0.00010346190651951206, 'batch_size': 8, 'dropout': 0.16783174788661628, 
    'dropout_prot': 0.0, 'pro_emb_dim': 480, 'extra_profc_layer': True
}, 
'scaling_config': {
    'trainer_resources': None, 'num_workers': 4, 'use_gpu': True, 
    'resources_per_worker': {'CPU': 2, 'GPU': 1}, 'placement_strategy': 'PACK'}
}
╭────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name              status         train_loop_config/lr     ...config/batch_size     ...op_config/dropout     ...onfig/pro_emb_dim   ...extra_profc_layer       iter     total time (s)       loss │
├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ TorchTrainer_4b9aa4e4   TERMINATED              0.00159805                         8              0.45743                          480   True                         15            3644.25   0.735962 │
│ TorchTrainer_7b7671c2   TERMINATED              0.000140605                       16              0.411237                        1024   True                         15            4145.42   0.776585 │
│ TorchTrainer_90ddc4c5   TERMINATED              0.000291404                       16              0.417351                         512   False                        15            3752.74   0.782918 │
│ TorchTrainer_4cdd2035   TERMINATED              0.00275608                         4              0.319747                         512   False                        15            3676.78   0.993408 │
│ TorchTrainer_ce47f94f   TERMINATED              0.00120235                        12              0.217379                         480   False                        15            3688.89   0.7571   │
│ TorchTrainer_c1c78aa6   TERMINATED              0.00636736                         4              0.0668473                       1024   False                        15            4179.49   0.734024 │
│ TorchTrainer_cf3dd845   TERMINATED              0.000113761                       16              0.0952068                        480   True                         15            3736.77   0.666004 │
│ TorchTrainer_f38c81ec   TERMINATED              0.00567628                        16              0.117595                         480   False                        15            3740.47   0.794086 │
│ TorchTrainer_d1acbf8c   TERMINATED              0.000579514                        8              0.0731969                        512   True                         15            3674.74   0.732458 │
│ TorchTrainer_c7c5cc43   TERMINATED              0.00507947                         4              0.434002                         512   True                         15            3706.88   0.735434 │
│ TorchTrainer_0aaebea1   TERMINATED              0.00400581                        16              0.429757                         512   True                         15            3751.79   0.734037 │
│ TorchTrainer_467a7661   TERMINATED              0.000106702                        8              0.242172                         512   True                         15            3693.5    0.708856 │
│ TorchTrainer_d6ce607c   TERMINATED              0.000100618                       12              0.176796                         480   True                         15            3692.86   0.663124 │
│ TorchTrainer_d2dbc636   TERMINATED              0.000485812                        8              0.000471847                      480   True                         15            3791.81   0.740899 │
│ TorchTrainer_def9c69e   TERMINATED              0.000447722                        8              0.0164498                        480   True                         15            3665.25   0.737996 │
│ TorchTrainer_dc87c6ae   TERMINATED              0.000103462                        8              0.167832                         480   True                         15            3649.84   0.644585 │
│ TorchTrainer_7874c645   TERMINATED              0.000194073                       12              0.0209678                        480   True                         15            3832.09   0.721365 │
│ TorchTrainer_51a39f59   TERMINATED              0.000206814                       12              0.169414                         480   True                         15            3704.9    0.693014 │
│ TorchTrainer_0ef3667a   TERMINATED              0.000149908                       12              0.168351                         480   True                         15            3691.68   0.65697  │
│ TorchTrainer_fcf66391   TERMINATED              0.000206367                       12              0.167047                         480   True                         15            3830.24   0.699192 │
│ TorchTrainer_25ed1e9b   TERMINATED              0.000217083                       12              0.17423                         1024   True                         15            4093.24   0.764177 │
│ TorchTrainer_dbc5b1b3   TERMINATED              0.0002815                         12              0.161344                        1024   True                         15            4080.07   0.758802 │
│ TorchTrainer_774996b0   TERMINATED              0.00027969                        12              0.282404                        1024   True                         15            4222.86   0.760099 │
│ TorchTrainer_889d961b   TERMINATED              0.000316104                        8              0.296708                        1024   True                         15            4065.56   0.732606 │
│ TorchTrainer_1d437011   TERMINATED              0.000107342                       12              0.293824                         480   True                         15            3697.12   0.680446 │
│ TorchTrainer_9b96d5c6   TERMINATED              0.000105715                       12              0.290922                         480   True                         15            3831.72   0.67241  │
│ TorchTrainer_8e1e38ab   TERMINATED              0.000104393                       12              0.203813                         480   True                         15            3704.57   0.675928 │
│ TorchTrainer_254b7489   TERMINATED              0.000100547                       12              0.213738                         480   True                         15            3693.03   0.664182 │
│ TorchTrainer_7f0984b4   TERMINATED              0.000146202                       12              0.220755                         480   True                         15            3831.34   0.684478 │
│ TorchTrainer_ad3ee4c2   TERMINATED              0.000153666                       12              0.21261                          480   True                         15            3708.86   0.691205 │
│ TorchTrainer_ae85d777   TERMINATED              0.000151375                        8              0.129261                         480   False                        15            3636.05   0.671802 │
│ TorchTrainer_c0266485   TERMINATED              0.000170168                        8              0.138339                         480   False                        15            3772.02   0.695039 │
│ TorchTrainer_53a98d1b   TERMINATED              0.000157354                        8              0.141885                         480   False                        15            3655.56   0.712433 │
│ TorchTrainer_072e10b0   TERMINATED              0.000166863                        8              0.13428                          480   False                        15            3641.66   0.699354 │
│ TorchTrainer_10c67a1a   TERMINATED              0.000152504                       12              0.138722                         480   True                         15            3830.68   0.653146 │
│ TorchTrainer_d7422ad2   TERMINATED              0.000136832                       12              0.188325                         480   True                         15            3704.42   0.660363 │
│ TorchTrainer_831b418d   TERMINATED              0.000130758                       12              0.190109                         480   True                         15            3693.59   0.669827 │
│ TorchTrainer_fbe28830   TERMINATED              0.000128652                       12              0.187059                         480   True                         15            3830.62   0.661278 │
│ TorchTrainer_dca9d926   TERMINATED              0.000142329                        4              0.187723                         480   True                         15            3673.25   0.834511 │
│ TorchTrainer_f640265e   TERMINATED              0.000829701                        4              0.252758                         480   True                         15            3635.52   0.749894 │
│ TorchTrainer_f126772f   TERMINATED              0.000246025                        4              0.248288                         480   True                         15            3764.2    0.971605 │
│ TorchTrainer_e8794ea4   TERMINATED              0.000242297                        4              0.244233                         480   True                         15            3672.73   0.943737 │
│ TorchTrainer_73018870   TERMINATED              0.00177218                        16              0.104248                         480   True                         15            3735.49   0.754655 │
│ TorchTrainer_2443df76   TERMINATED              0.000226738                       16              0.107037                         512   True                         15            3894.22   0.685572 │
│ TorchTrainer_facb29ef   TERMINATED              0.000131008                       12              0.10441                          480   True                         15            3704.88   0.654693 │
│ TorchTrainer_80dcdafb   TERMINATED              0.000128984                       16              0.106138                         480   True                         15            3735.25   0.650588 │
│ TorchTrainer_ca1902ff   TERMINATED              0.000127723                       12              0.153881                         512   True                         15            3725.63   0.647233 │
│ TorchTrainer_2b285a3c   TERMINATED              0.000338105                       12              0.10172                          512   True                         15            3851.49   0.750991 │
│ TorchTrainer_a43d96ed   TERMINATED              0.000130856                       12              0.0829247                        512   True                         15            3713.21   0.656583 │
│ TorchTrainer_a1976fbb   TERMINATED              0.000184014                       16              0.087322                         512   True                         15            3762.67   0.664556 │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

@jyaacoub jyaacoub linked a pull request Dec 18, 2023 that will close this issue