-
Notifications
You must be signed in to change notification settings - Fork 77
/
Copy pathutils.py
69 lines (53 loc) · 1.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gzip
from sklearn import svm
from sklearn.metrics import accuracy_score
import numpy as np
import theano
from keras.utils.data_utils import get_file
def load_data(data_file, url):
    """
    Fetch a gzip-compressed pickle and return its three splits as numpy arrays.

    Parameters
    ----------
    data_file : str
        Local file name handed to keras' ``get_file``.
    url : str
        Origin URL handed to ``get_file``.

    Returns
    -------
    list
        ``[(train_x, train_y), (valid_x, valid_y), (test_x, test_y)]``, each pair
        produced by ``make_numpy_array``. The pickle is expected to contain
        exactly three (x, y) splits.
    """
    print('loading data ...')
    path = get_file(data_file, origin=url)
    # NOTE(security): this unpickles a file obtained from `url`; pickle can run
    # arbitrary code, so only point this at trusted sources.
    # The context manager guarantees the gzip handle is closed even when
    # unpickling raises (the previous manual close() was skipped on error).
    with gzip.open(path, 'rb') as f:
        train_set, valid_set, test_set = load_pickle(f)
    return [make_numpy_array(split) for split in (train_set, valid_set, test_set)]
def make_numpy_array(data_xy):
    """
    Convert an ``(x, y)`` pair into numpy arrays.

    ``x`` is cast to theano's configured float dtype (``theano.config.floatX``),
    ``y`` is cast to ``int32``. Returns the converted pair as a tuple.
    """
    features, labels = data_xy
    return (np.asarray(features, dtype=theano.config.floatX),
            np.asarray(labels, dtype='int32'))
def svm_classify(data, C):
    """
    Fit a linear SVM on the training split and score it on the held-out splits.

    Parameters
    ----------
    data : sequence
        ``data[0]``, ``data[1]``, ``data[2]`` are the train, validation and test
        splits. Each split is unpacked as ``(features, <ignored>, labels)``.
    C : float
        Penalty factor forwarded to ``LinearSVC``.

    Returns
    -------
    list
        ``[test_accuracy, valid_accuracy]`` (note: test first).
    """
    (x_train, _, y_train), (x_valid, _, y_valid), (x_test, _, y_test) = data[0], data[1], data[2]

    print('training SVM...')
    classifier = svm.LinearSVC(C=C, dual=False)
    # Training labels are flattened; presumably they may arrive as a column
    # vector — TODO confirm against callers.
    classifier.fit(x_train, y_train.ravel())

    test_score = accuracy_score(y_test, classifier.predict(x_test))
    valid_score = accuracy_score(y_valid, classifier.predict(x_valid))
    return [test_score, valid_score]
def load_pickle(f):
    """
    Load and return the object stored in the open binary file ``f``.

    Works on both Python 2 and Python 3. On Python 3 the stream is read with
    ``encoding='latin1'``, which lets byte strings written by Python 2 (for
    example inside pickled numpy arrays) decode without errors.

    Raises whatever the underlying unpickler raises. Unlike the earlier
    try/except-TypeError approach, an error raised mid-load is no longer
    swallowed and retried from a partially consumed stream.

    NOTE(security): unpickling can execute arbitrary code; only use on trusted
    input.
    """
    import sys

    if sys.version_info[0] >= 3:
        import pickle
        return pickle.load(f, encoding='latin1')

    # Python 2: cPickle.load does not accept an `encoding` keyword.
    import cPickle
    return cPickle.load(f)