-
Notifications
You must be signed in to change notification settings - Fork 0
/
LinearSVM.py
40 lines (34 loc) · 1.16 KB
/
LinearSVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
# @File : LinearSVM.py
# @Author : Hua Guo
# @Time : 2021/9/10 10:34 PM
# @Desc : Inspired by https://github.com/cperales/SupportVectorMachine
import numpy as np
from src.Model.SVM.solver import fit_soft, fit
from src.Model.SVM import BaseSVM
class LinearSVM(BaseSVM):
    """Linear support vector machine fitted from the dual solution.

    Callers supply labels as 0/1; they are mapped to -1/+1 for training and
    back to 0/1 for prediction through ``change_label`` (inherited from
    ``BaseSVM``).

    Attributes:
        w: Weight vector of the separating hyperplane, ``None`` until
            :meth:`fit` has run.
        b: Bias (intercept) of the separating hyperplane, ``None`` until
            :meth:`fit` has run.
    """

    # Multipliers at or below this value are treated as zero: a numerical QP
    # solver rarely returns exact zeros for non-support vectors.
    _ALPHA_TOL = 1e-5

    def __init__(self):
        self.w = None
        self.b = None

    def fit(self, X, y, soft=False, C=1):
        """Estimate ``w`` and ``b`` from the training data.

        Args:
            X: Training samples, one row per sample.
            y: Labels using 0 for the negative class (converted to -1).
            soft: If true, use ``fit_soft`` (presumably the soft-margin dual
                solver); otherwise use the module-level ``fit`` solver.
            C: Penalty passed to ``fit_soft``; ignored when ``soft`` is false.
        """
        y = self.change_label(y, replace={0.: -1.})
        # The solvers work on floats; np.array also copies, like astype did.
        X = np.array(X, dtype=float)
        y = np.array(y, dtype=float)

        # NOTE: inside a method body the bare name ``fit`` resolves to the
        # solver imported at module level, not to this method.
        if soft:
            alphas = fit_soft(x=X, y=y, C=C)
        else:
            alphas = fit(x=X, y=y)
        # Assumes the solver returns one multiplier per sample; force a
        # column so the product below broadcasts row-wise for (n,) or (n, 1).
        alphas = np.asarray(alphas, dtype=float).reshape(-1, 1)

        # See Li Hang, "Statistical Learning Methods", Algorithm 7.2:
        # w* = sum_i alpha_i * y_i * x_i
        self.w = np.sum(alphas * y[:, None] * X, axis=0)

        # b* = y_j - w* . x_j holds only for support vectors (alpha_j > 0),
        # which lie exactly on the margin. Averaging over them (rather than
        # over every sample) keeps points far from the margin from skewing b.
        flat_alphas = alphas.ravel()
        support = flat_alphas > self._ALPHA_TOL
        if soft:
            # With slack, only 0 < alpha < C guarantees a point on the margin.
            on_margin = support & (flat_alphas < C - self._ALPHA_TOL)
            if on_margin.any():
                support = on_margin
        if not support.any():
            # Degenerate solution: fall back to using every sample.
            support = np.ones(flat_alphas.shape, dtype=bool)
        self.b = np.mean(y[support] - np.dot(X[support], self.w))

        # Normalization: rescale so that ||w|| == 1. The decision sign is
        # unchanged. Skip when w is all zeros to avoid dividing by zero.
        norm = np.linalg.norm(self.w)
        if norm > 0:
            self.w = self.w / norm
            self.b = self.b / norm

    def predict(self, X):
        """Classify the rows of ``X`` and return labels in the 0/1 encoding.

        Args:
            X: Samples to classify, one row per sample.

        Returns:
            The result of ``change_label`` applied to the sign of the
            decision values, with -1 mapped back to 0.
        """
        scores = np.dot(self.w, X.T) + self.b * np.ones(X.shape[0])
        y = np.sign(scores)
        return self.change_label(y, replace={-1.: 0.})