-
Notifications
You must be signed in to change notification settings - Fork 2
/
show_results.py
37 lines (30 loc) · 8.78 KB
/
show_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
Summarise the results of experiment 1 as a table of mean±std per model and metric.
The raw values below were copied verbatim from the output of run_experiment_1.py.
"""
import numpy as np
import pandas as pd
# This is the output generated by running run_experiment_1.py, pasted verbatim.
# Structure: {model_name: {metric_name: [score, score, ...]}} where every model
# reports the same five metrics: 'precision', 'recall', 'auroc', 'dp', 'ftu'.
# NOTE(review): each list appears to hold 10 values, presumably one per repeated
# run of the experiment -- confirm against run_experiment_1.py.
# NOTE(review): 'dp' / 'ftu' are presumably fairness scores (demographic parity /
# fairness through unawareness) -- confirm against run_experiment_1.py.
experiment_results = {
'original': {'precision': [0.8913881748071979, 0.8775765146783261, 0.8790523690773068, 0.8659485338120886, 0.881795195954488, 0.8761006289308176, 0.8790170132325141, 0.8755472170106317, 0.8905440414507773, 0.8623232944068838], 'recall': [0.9222074468085106, 0.9341755319148937, 0.9375, 0.9621010638297872, 0.9275265957446809, 0.9261968085106383, 0.9275265957446809, 0.9308510638297872, 0.9142287234042553, 0.932845744680851], 'auroc': [0.7907408201784489, 0.7695071207961566, 0.7731854838709677, 0.7552440803019904, 0.7752552333562115, 0.7645096945778997, 0.7702149107755663, 0.7648206932052162, 0.7867514584763212, 0.7406164207275222], 'dp': [0.2198180594603273, 0.1685498689353926, 0.1737229362108511, 0.12423640418712789, 0.16744183898277687, 0.1760683760683761, 0.17452760106433274, 0.1939568548186722, 0.22016636865427974, 0.17318162020905925], 'ftu': [0.02650000000000008, 0.013499999999999956, 0.02949999999999997, 0.0030000000000000027, 0.024499999999999966, 0.007499999999999951, 0.015499999999999958, 0.03200000000000003, 0.039999999999999925, 0.03650000000000009]},
'vanilla_gan': {'precision': [0.6560509554140127, 0.8007202881152461, 0.8686440677966102, 0.8620689655172413, 0.8157099697885196, 0.7248787248787248, 0.7569444444444444, 0.7904761904761904, 0.5555555555555556, 0.8888888888888888], 'recall': [0.13715046604527298, 0.4440745672436751, 0.13648468708388814, 0.2829560585885486, 0.5392809587217043, 0.6964047936085219, 0.7982689747003995, 0.11051930758988016, 0.043275632490013316, 0.255659121171771], 'auroc': [0.4601414980828774, 0.555370616955171, 0.5371178455499762, 0.5732049369247963, 0.585905539601816, 0.4496080192942208, 0.512588302611244, 0.5110829469676309, 0.4694289809036412, 0.5796367895015482], 'dp': [ 0.12390500784849867, 0.3274573165960072, 0.11442696753345394, 0.027004819576595368, 0.3181588020567209, 0.1327639143984276, 0.1604315983778235, 0.11050041659186426, 0.03756922099622997, 0.6679095374219177], 'ftu': [0.04500000000000001, 0.314, 0.007000000000000006, 0.1555, 0.09850000000000003, 0.21999999999999997, 0.124, 0.2435, 0.08499999999999999, 0.7295]},
'wgan_gp': {'precision': [0.8655589123867069, 0.6837209302325581, 0.6888297872340425, 0.7236220472440945, 0.7591564927857936, 0.7501323451561673, 0.8954248366013072, 0.7950481430536451, 0.883495145631068, 0.7954815695600476], 'recall': [0.381491344873502, 0.5872170439414115, 0.3448735019973369, 0.6118508655126498, 0.45539280958721706, 0.9434087882822902, 0.27363515312916115, 0.3848202396804261, 0.24234354194407456, 0.44540612516644473], 'auroc': [0.6013882427178755, 0.3839699677538383, 0.4374969919625238, 0.453515794202108, 0.5098249188498335, 0.49780881181182784, 0.5886248054802432, 0.5428117262659159, 0.5729789998876998, 0.5500122995310135], 'dp': [0.5934063404821418, 0.03799962253554845, 0.052679307122570074, 0.14299918522917154, 0.24493530167235167, 0.06120217824607921, 0.17395357187245383, 0.3891727544317549, 0.29822453610999866, 0.31135983870299533], 'ftu': [0.5775, 0.010000000000000009, 0.118, 0.01649999999999996, 0.2945, 0.04299999999999993, 0.008000000000000007, 0.39499999999999996, 0.2585, 0.35950000000000004]},
'fairgan': {'precision': [0.8875, 0.8037486218302095, 0.9038613081166272, 0.7976318622174381, 0.8315217391304348, 0.8997772828507795, 0.8070273284997211, 0.7955043859649122, 0.8231884057971014, 0.7959073774905762], 'recall': [0.8035952063914781, 0.9707057256990679, 0.7636484687083888, 0.9866844207723036, 0.9167776298268975, 0.8069241011984021, 0.9633821571238349, 0.9660452729693741, 0.9454061251664447, 0.9840213049267643], 'auroc': [0.7481831453644137, 0.6279231439740319, 0.759334274514837, 0.6158321702255093, 0.678268332985738, 0.7679198819245021, 0.6343015203289857, 0.6085246445168155, 0.6664781629848289, 0.6114885641099685], 'dp': [0.27656175916847336, 0.02148325116576677, 0.4998849193745138, 0.08584554338769745, 0.056932687040540664, 0.3597581465574782, 0.05301304093648018, 0.06017335745423236, 0.10063800698769554, 0.05205557013243478], 'ftu': [0.136, 0.055499999999999994, 0.40049999999999997, 0.05449999999999999, 0.08599999999999997, 0.1645000000000001, 0.033499999999999974, 0.009000000000000008, 0.02200000000000002, 0.009000000000000008]},
'decaf_nd': {'precision': [0.8795454545454545, 0.902200488997555, 0.9344703770197487, 0.8675213675213675, 0.8573307034845496, 0.8711656441717791, 0.911620294599018, 0.8949771689497716, 0.8922480620155039, 0.8571428571428571], 'recall': [0.7719414893617021, 0.7360372340425532, 0.692154255319149, 0.8098404255319149, 0.8670212765957447, 0.7553191489361702, 0.7406914893617021, 0.7819148936170213, 0.7652925531914894, 0.8218085106382979], 'auroc': [0.7256884866163349, 0.7470508750857927, 0.7724884179821551, 0.7174202127659575, 0.7147606382978723, 0.7083047357584078, 0.7614747769389156, 0.7518445435827043, 0.7425253088538091, 0.7032429649965682], 'dp': [0.34153277, 0.41648984, 0.45542848, 0.2933951, 0.2351135, 0.36700368, 0.37226972, 0.36494797, 0.4372166, 0.24334955], 'ftu': [0.069000006, 0.1365, 0.096000016, 0.1365, 0.057500005, 0.07999998, 0.14350003, 0.11800003, 0.19999999, 0.102]},
'decaf_dp': {'precision': [0.7470253491981376, 0.7531446540880503, 0.7562296858071506, 0.751389590702375, 0.7515243902439024, 0.7601646937725167, 0.751503006012024, 0.751389590702375, 0.7546412443552434, 0.751131221719457], 'recall': [0.9601063829787234, 0.9554521276595744, 0.9281914893617021, 0.9886968085106383, 0.9833776595744681, 0.9820478723404256, 0.9973404255319149, 0.9886968085106383, 1.0, 0.9933510638297872], 'auroc': [0.48710964310226496, 0.502927676733013, 0.5104667124227866, 0.4983806623198353, 0.4987452814001373, 0.5212658716540838, 0.49867021276595747, 0.4983806623198353, 0.5070564516129032, 0.49768359643102267], 'dp': [0.018667638, 0.020073175, 0.029135108, 0.005097389, 0.009011865, 0.02352941, 0.003889978, 0.00210315, 0.0029906034, 0.007862747], 'ftu': [0.0015000105, 0.023500025, 0.0055000186, 0.004000008, 0.0029999614, 0.013500035, 0.004999995, 0.00050002337, 0.0, 0.006000042]},
'decaf_cf': {'precision': [0.7657111356119074, 0.7944890929965557, 0.7893825735718407, 0.7598752598752598, 0.7666839110191412, 0.7655426765015806, 0.7594142259414226, 0.7663157894736842, 0.764859228362878, 0.7556848228450556], 'recall': [0.9235372340425532, 0.9202127659574468, 0.9095744680851063, 0.9720744680851063, 0.9853723404255319, 0.9660904255319149, 0.9654255319148937, 0.9680851063829787, 0.9753989361702128, 0.9501329787234043], 'auroc': [0.5333411976664378, 0.5992192862045298, 0.5868436856554564, 0.5203114275909403, 0.5380490734385724, 0.5344565030885381, 0.519003088538092, 0.5364619080301991, 0.5330623713109128, 0.5093406829100893], 'dp': [0.024875283, 0.0672586, 0.07143378, 0.0011213422, 0.024249434, 0.016911745, 0.0125283, 0.0090456605, 0.0073554516, 0.026818871], 'ftu': [0.051999986, 0.0255, 0.0059999824, 0.011500001, 0.010999978, 0.059000015, 0.0065000057, 0.012499988, 0.0055000186, 0.027499974]},
'decaf_ftu': {'precision': [0.8587650816181689, 0.8972659486329743, 0.9059011164274322, 0.8229695431472082, 0.847307430129516, 0.8913560666137986, 0.8558495821727019, 0.8260300850228908, 0.8826151560178306, 0.8694444444444445], 'recall': [0.8045212765957447, 0.7200797872340425, 0.7553191489361702, 0.8623670212765957, 0.8264627659574468, 0.7473404255319149, 0.817154255319149, 0.8397606382978723, 0.7898936170212766, 0.8324468085106383], 'auroc': [0.701655799588195, 0.7350398936170213, 0.7587079615648593, 0.6499335106382979, 0.6874249313658201, 0.7355653740562802, 0.6999077728208648, 0.6517351578586136, 0.7356726149622511, 0.7267072752230612], 'dp': [0.254201, 0.33026803, 0.38632923, 0.09094083, 0.22282213, 0.3319853, 0.22539997, 0.18188602, 0.35021257, 0.22198081], 'ftu': [0.05799997, 0.0625, 0.033000052, 0.037999988, 0.050500035, 0.061499953, 0.004499972, 0.060000002, 0.0255, 0.015999973]}
}
def display_results(results, decimals=3):
    """Format a sequence of scores as ``"mean±std"``.

    Args:
        results: Numeric scores (any sequence ``np.mean`` / ``np.std`` accept).
        decimals: Digits after the decimal point for both the mean and the
            standard deviation. Defaults to 3.

    Returns:
        A string such as ``"1.500±0.500"`` (for ``[1.0, 2.0]``). The standard
        deviation is the population value (``np.std`` with its default
        ``ddof=0``).
    """
    mean = np.mean(results)
    std = np.std(results)
    # Fixed-point formatting keeps trailing zeros ("0.850", not "0.85"), so
    # every cell of the printed table shows the same number of decimals.
    return f'{mean:.{decimals}f}±{std:.{decimals}f}'
# Row order of the printed table (intentionally different from the insertion
# order of `experiment_results`) and the order of the metric columns.
MODEL_ORDER = ['original', 'vanilla_gan', 'wgan_gp', 'fairgan',
               'decaf_nd', 'decaf_ftu', 'decaf_cf', 'decaf_dp']
METRICS = ['precision', 'recall', 'auroc', 'ftu', 'dp']

# Build the table directly from formatted string cells. Pre-allocating a
# float64 frame and writing strings into it with `.loc` is an
# incompatible-dtype setitem, which pandas >= 2.1 deprecates (FutureWarning)
# and future versions reject. Indexing `experiment_results[model]` also makes
# a model missing from the results fail loudly instead of leaving a row of 0.0.
df = pd.DataFrame(
    [
        {
            'model': model,
            **{metric: display_results(experiment_results[model][metric])
               for metric in METRICS},
        }
        for model in MODEL_ORDER
    ],
    columns=['model', *METRICS],
)
print(df)