-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathinterpretation_fp.py
63 lines (52 loc) · 2.05 KB
/
interpretation_fp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import csv
import numpy as np
from fpgnn.tool import set_interfp_argument, set_log, get_scaler, load_args, load_data, load_model, rmse
from fpgnn.train import predict
from fpgnn.data import MoleDataSet
def _fp_length(args):
    """Return the number of ``fp_changebit`` iterations to run for ``args``.

    Iteration 0 is the unperturbed baseline, so the number of bits actually
    perturbed is one less than the returned value.
    """
    # NOTE(review): 1490 / 1025 presumably equal one baseline slot plus the
    # fingerprint width the model uses for fp_type 'mixed' vs. any other
    # fp_type -- confirm against the fpgnn fingerprint featurizer.
    fp_type = getattr(args, 'fp_type', 'mixed')
    return 1490 if fp_type == 'mixed' else 1025


def _write_importance_csv(path, rows):
    """Write ``[bit, importance]`` rows to ``path`` as CSV, preceded by a header row."""
    with open(path, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['No_of_Bit_Changed', 'Importance'])
        writer.writerows(rows)


def make_fp_interpretation(args, log):
    """Score a trained model once per fingerprint bit and write per-bit importance.

    Loads the scaler and training arguments stored at ``args.model_path`` and
    copies every training option that ``args`` does not already define onto
    ``args``. Then, for each ``fp_changebit`` in ``range(_fp_length(args))``,
    it sets ``args.fp_changebit``, reloads the model with ``pred_args=args``,
    predicts on the data at ``args.predict_path`` and computes RMSE against
    its labels. Iteration 0 is logged as the unchanged baseline. Every later
    iteration contributes a row ``[bit, baseline_rmse - changed_rmse]`` to the
    CSV written at ``args.result_path``.

    Args:
        args: option namespace (as built by ``set_interfp_argument``). It is
            mutated in place: missing training options are added, and
            ``fp_changebit`` is left at the last bit processed.
        log: logger whose ``info`` method receives progress messages.
    """
    info = log.info
    info('Load args.')
    scaler = get_scaler(args.model_path)
    train_args = load_args(args.model_path)
    # Fill in training-time options without overriding anything already set on args.
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    info('Load data.')
    test_data = load_data(args.predict_path, args)
    test_label = np.squeeze(np.array(test_data.label()))

    result = []
    orig_score = 0
    # 0: nothing changed (baseline); 1..fp_length-1: the bit that is changed.
    for fp_changebit in range(_fp_length(args)):
        args.fp_changebit = fp_changebit
        # The model is rebuilt each iteration -- presumably because load_model
        # picks up fp_changebit from pred_args at construction; confirm in fpgnn.
        model = load_model(args.model_path, args.cuda, pred_args=args)
        model_pred = np.array(predict(model, test_data, args.batch_size, scaler))
        if fp_changebit == 0:
            info('Original fingerprint. Nothing changed.')
            orig_score = rmse(test_label, model_pred)
            continue

        info(f'Change fingerprint bit : {fp_changebit}')
        change_score = rmse(test_label, model_pred)
        # NOTE(review): RMSE is lower-is-better, so a bit whose change hurts the
        # model yields a *negative* value here; confirm this sign is intended.
        res = orig_score - change_score
        info(f'Change Importance: {res}')
        result.append([fp_changebit, res])

    _write_importance_csv(args.result_path, result)
if __name__ == '__main__':
    # Script entry point: parse the interpretation CLI options, create the
    # 'inter_fp' logger at the requested log path, then run the per-bit analysis.
    cli_args = set_interfp_argument()
    fp_logger = set_log('inter_fp', cli_args.log_path)
    make_fp_interpretation(cli_args, fp_logger)