-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrbf_SVM.py
98 lines (80 loc) · 2.75 KB
/
rbf_SVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from time import time
import numpy as np
import datetime
from sklearn.metrics import classification_report, roc_curve, auc
import matplotlib.pyplot as plt
import xlrd
def _is_zero_label(value):
    """Return True if *value* parses as the number zero (e.g. 0.0, '0.0', '0')."""
    try:
        return float(value) == 0.0
    except (TypeError, ValueError):
        # Non-numeric label cells (e.g. blanks) are treated as class 1,
        # matching the "anything that is not zero" rule below.
        return False


def split_features_labels(rows):
    """Split table rows into a float feature matrix and binary 0/1 labels.

    The last column of each row is the label; all other columns are features.
    A label that parses as zero becomes class 0, anything else class 1.

    Args:
        rows: Sequence of equal-length rows (e.g. lists of cell values, which
            may be floats or numeric strings).

    Returns:
        (X, y): X is a float array of shape (n_rows, n_cols - 1); y is an int
        array of shape (n_rows,). An empty input yields arrays of shape
        (0, 0) and (0,).

    Raises:
        ValueError: if a feature cell cannot be converted to float.
    """
    # dtype=object keeps each cell as read (float or str) instead of letting
    # NumPy coerce the whole table to strings as soon as one cell is text.
    table = np.asarray(rows, dtype=object)
    if table.size == 0:
        return np.empty((0, 0)), np.empty(0, dtype=int)
    X = table[:, :-1].astype(float)
    # Compare numerically: a string comparison against '0.0' silently maps
    # float 0.0 cells (and '0') to class 1.
    y = np.array([0 if _is_zero_label(v) else 1 for v in table[:, -1]], dtype=int)
    return X, y


def load_excel(path):
    """Load the first worksheet of an Excel workbook as (features, labels).

    Every row is treated as a sample: the last column is the label (zero ->
    class 0, anything else -> class 1) and the remaining columns are features.

    NOTE(review): xlrd >= 2.0 only opens legacy .xls files; reading .xlsx
    requires xlrd < 2.0 (or switching to openpyxl) — confirm the pinned version.

    Args:
        path: Path to the workbook file.

    Returns:
        (X, y) as produced by ``split_features_labels``.
    """
    sheet = xlrd.open_workbook(path).sheet_by_index(0)
    rows = [sheet.row_values(i) for i in range(sheet.nrows)]
    return split_features_labels(rows)
# Raw string: '\X', '\s' and '\e' are invalid escape sequences in a normal
# string literal (SyntaxWarning on Python 3.12+).
path = r'D:\XML2SVM\svm\est.xlsx'
X, y = load_excel(path)
print(X.shape)
print(y)
print(X)
# Split BEFORE scaling so the scaler's mean/std come from the training split
# only; fitting on the full dataset leaks test-set statistics into training.
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=420)
scaler = StandardScaler().fit(Xtrain)
Xtrain = scaler.transform(Xtrain)
Xtest = scaler.transform(Xtest)
print(Xtrain)
kernel = "rbf"
time0 = time()
# `degree` only affects the polynomial kernel, so it is omitted for rbf.
clf = SVC(kernel=kernel,
          gamma="auto",
          cache_size=7000  # kernel cache size in MB (sklearn default: 200)
          ).fit(Xtrain, Ytrain)
print("The accuracy under kernel %s is %f" % (kernel, clf.score(Xtest, Ytest)))
# Format the elapsed duration in UTC: fromtimestamp() without a tz applies the
# local UTC offset, which corrupts the minutes field in half-hour time zones.
elapsed = time() - time0
print(datetime.datetime.fromtimestamp(elapsed, tz=datetime.timezone.utc).strftime("%M:%S:%f"))

# Scan gamma for the rbf kernel: np.logspace(-10, 1, 50) returns 50 numbers
# evenly spaced on a log10 scale, i.e. from 10**-10 up to 10**1.
gamma_range = np.logspace(-10, 1, 50)
score = []
for gamma in gamma_range:
    model = SVC(kernel="rbf", gamma=gamma, cache_size=5000).fit(Xtrain, Ytrain)
    score.append(model.score(Xtest, Ytest))
best_idx = int(np.argmax(score))  # first index of the maximum, like list.index(max(...))
best_gamma = gamma_range[best_idx]
print(score[best_idx], best_gamma)
# Refit with the selected gamma so `clf` (used for the report below and by the
# ROC plot) is the best model of the scan, not whichever was fitted last.
clf = SVC(kernel="rbf", gamma=best_gamma, cache_size=5000).fit(Xtrain, Ytrain)
# Report on the held-out split. Note gamma was also chosen on this split, so
# these numbers are optimistic; use a separate validation split for unbiased figures.
print(classification_report(Ytest, clf.predict(Xtest)))
def print_roc(Ytest, Xtest, model=None):
    """Print the ROC curve points of a fitted classifier, plot the curve, and return its AUC.

    Reference for the plotting recipe:
    https://blog.csdn.net/qq_45769063/article/details/106649523

    Args:
        Ytest: True binary labels for ``Xtest``.
        Xtest: Feature matrix to score.
        model: Fitted classifier to evaluate. Defaults to the module-level
            ``clf``.

    Returns:
        The area under the ROC curve.
    """
    if model is None:
        model = clf
    # ROC sweeps a threshold over a continuous score; hard predict() labels
    # give a single operating point, so the "curve" and AUC degenerate.
    # SVC exposes decision_function; fall back to predict() otherwise.
    if hasattr(model, "decision_function"):
        scores = model.decision_function(Xtest)
    else:
        scores = model.predict(Xtest)
    # False positive rate, true positive rate, and the thresholds producing them.
    fpr, tpr, threshold = roc_curve(Ytest, scores)
    print(fpr)
    print(tpr)
    print(threshold)
    roc_auc = auc(fpr, tpr)
    lw = 2
    # Create a single figure; a second plt.figure() call would leave an empty
    # extra figure open.
    plt.figure(figsize=(10, 10))
    # False positive rate on the x axis, true positive rate on the y axis.
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    # Diagonal = performance of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()
    return roc_auc
# Plot the ROC curve on the held-out split and keep its AUC.
roc_auc = print_roc(Ytest=Ytest, Xtest=Xtest)