-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
131 lines (104 loc) · 6.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import pandas as pd
from scripts.calc_num_parameters import process_data
from scripts.calc_positive_pct_per_task import calculate_pct_positive_class
from scripts.plot_performances import plot_performance_ranges, plot_molformer_by_size, plot_chemberta_2, \
plot_performance_by_representation_or_objectives
def generate_figure_2_and_supp_figure_7():
    """
    Plot figure 2 and supplementary figure 7 from the performance-comparison tables.

    For both task types ('classification' and 'regression') the transformers, ml and dl CSV files are read from
    the ``data/performance_comparison`` directory. Figure 2 is drawn from the ml/dl rows whose 'reporting' value
    is not 'copied'; supplementary figure 7 is drawn from all ml/dl rows.
    :return: the data directory that was read and handed to the plotting function
    """
    data_path = os.path.join('data', 'performance_comparison')
    transformers_dict = {}
    # Both containers map family ('ml' / 'dl') -> task type -> DataFrame.
    comparable_tables = {'ml': {}, 'dl': {}}
    all_tables = {'ml': {}, 'dl': {}}
    reimplemented_sources, all_sources = set(), set()

    for suffix in ['classification', 'regression']:
        transformers_df = pd.read_csv(os.path.join(data_path, f'transformers_{suffix}.csv'), index_col=0)
        baseline_frames = {
            'ml': pd.read_csv(os.path.join(data_path, f'ml_{suffix}.csv'), index_col=0),
            'dl': pd.read_csv(os.path.join(data_path, f'dl_{suffix}.csv'), index_col=0),
        }

        # The source transformer is the part of each index label before the first space.
        transformers_df['source_transformer'] = [label.split(' ')[0] for label in transformers_df.index]
        transformers_dict[suffix] = transformers_df

        for family, frame in baseline_frames.items():
            # The ml/dl results were taken from the transformer articles: the index holds the transformer model
            # name and the first column holds the ml/dl model name. Keep the transformer name as a regular column
            # (it is needed later), then join both names into the index so that no index label is duplicated
            # when plotting.
            frame['source_transformer'] = frame.index
            frame.index = frame.index + '_' + frame.iloc[:, 0]

            helper_columns = [frame.columns[0], 'reporting']

            # data for figure 2: every row that is not marked as 'copied'
            not_copied = frame['reporting'] != 'copied'
            comparable_frame = frame[not_copied].drop(columns=helper_columns)
            reimplemented_sources.update(comparable_frame['source_transformer'].to_list())
            comparable_tables[family][suffix] = comparable_frame

            # data for supplementary figure 7: all rows
            full_frame = frame.drop(columns=helper_columns)
            all_sources.update(full_frame['source_transformer'].to_list())
            all_tables[family][suffix] = full_frame

    # assigning a color per model for plotting
    # NOTE(review): '#0072b2' is listed twice, so two models end up with the same color, and the list holds 12
    # entries, so more than 12 distinct source models would raise an IndexError below - confirm both are intended.
    colors = ['#f0e442', '#009e73', 'rosybrown', '#0072b2', 'orange', '#0072b2', '#cc79a7', 'deepskyblue', 'magenta',
              'darkviolet', 'lightcoral', 'indianred']
    models_colors_all = {}
    for position, model in enumerate(sorted(all_sources)):
        models_colors_all[model] = colors[position]
    # The comparable subset reuses the colors assigned above so a model looks the same in both figures.
    models_colors_comparable = {}
    for model in sorted(reimplemented_sources):
        models_colors_comparable[model] = models_colors_all[model]

    # Plot figure 2
    plot_performance_ranges(transformers_dict, comparable_tables['ml'], comparable_tables['dl'], data_path,
                            models_colors_comparable, comparable_only=True)
    # Plot supplementary figure 7
    plot_performance_ranges(transformers_dict, all_tables['ml'], all_tables['dl'], data_path, models_colors_all,
                            comparable_only=False)
    return data_path
def generate_figure_4():
    """
    Draw the two sub-figures of figure 4 (the final figure is assembled from them by hand in keynote).
    :return: the data directory handed to both plotting functions
    """
    size_study_dir = os.path.join('data', 'pretrain_dataset_size')
    # Same order as the panels are listed: MolFormer first, then ChemBERTa-2.
    for draw_panel in (plot_molformer_by_size, plot_chemberta_2):
        draw_panel(size_study_dir)
    return size_study_dir
def generate_figure_5():
    """
    Draw the two sub-figures of figure 5 (the final figure is assembled from them by hand in keynote).
    :return: the data directory handed to the plotting function
    """
    representation_dir = os.path.join('data', 'representation')
    # One plot per model flag, always with objectives switched off.
    # NOTE(review): this function passes `mol_bert=True`, while generate_figure_6_and_supp_figures_8_n_9 passes
    # `molbert=True` to the same plotting function - confirm which keyword its signature actually accepts.
    for model_flag in ('mol_bert', 'mat'):
        plot_performance_by_representation_or_objectives(representation_dir, **{model_flag: True}, objectives=False)
    return representation_dir
def generate_figure_6_and_supp_figures_8_n_9():
    """
    Draw the objectives plots for figure 6 and supplementary figures 8 and 9.
    :return: the data directory handed to the plotting function
    """
    objectives_dir = os.path.join('data', 'objectives')
    # One plot per model flag, always with objectives switched on.
    # NOTE(review): `molbert` is spelled without an underscore here, while generate_figure_5 passes `mol_bert=True`
    # to the same plotting function - confirm which keyword its signature actually accepts.
    for model_flag in ('mat', 'molbert', 'k_bert'):
        plot_performance_by_representation_or_objectives(objectives_dir, **{model_flag: True}, objectives=True)
    return objectives_dir
def generate_num_parameters_in_table_6():
    """
    Load the transformer model shapes and pass them to `process_data` for the parameter counts of table 6.
    :return: the data directory that holds the shapes file and is handed to `process_data`
    """
    parameters_dir = os.path.join('data', 'num_parameters')
    shapes_csv = os.path.join(parameters_dir, 'transformer_models_shape.csv')
    # The first CSV column is used as the row index.
    model_shapes = pd.read_csv(shapes_csv, index_col=0)
    process_data(model_shapes, parameters_dir)
    return parameters_dir
def generate_supp_tables_11_12_n_13():
    """
    Run `calculate_pct_positive_class` for the clintox, sider and tox21 names (supplementary tables 11, 12 and 13).
    :return: the data directory handed to `calculate_pct_positive_class`
    """
    multi_task_dir = os.path.join('data', 'multi_task_data')
    calculate_pct_positive_class(multi_task_dir, ['clintox', 'sider', 'tox21'])
    return multi_task_dir
if __name__ == '__main__':
    # Regenerate every figure and table in turn. Each generator returns the data directory it worked on,
    # which is echoed so the user knows where to look for the results.
    d_path = generate_figure_2_and_supp_figure_7()
    print(f'Figures 2 and supplementary figure 7 are saved to {d_path}')

    d_path = generate_figure_4()
    print(f'Individual subfigures of figure 4 are saved to {d_path}')

    d_path = generate_figure_5()
    # This message must name figure 5 (not figure 4) so the log matches the generator that just ran.
    print(f'Individual subfigures of figure 5 are saved to {d_path}')

    d_path = generate_figure_6_and_supp_figures_8_n_9()
    print(f'Figures 6 and supplementary figures 8 and 9 are saved to {d_path}')

    d_path = generate_num_parameters_in_table_6()
    print(f'parameters count shown in table 6 are saved to {d_path}')

    d_path = generate_supp_tables_11_12_n_13()
    print(f'Tables 11, 12, and 13 are saved to {d_path}')