-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsvm_visualization.py
59 lines (47 loc) · 1.87 KB
/
svm_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from visualization import Visualization
from sklearn.svm import LinearSVC
from vcd import VisualConceptDetection
import numpy as np
from datamanagers.CaltechManager import CaltechManager
from itertools import izip
import sys
import os
def get_image_title(prediction, real):
    """Describe a single SVM decision as a human-readable image title.

    The sign of ``prediction`` determines the predicted label (strictly
    positive -> 1, otherwise 0). That label is compared with ``real`` to
    classify the image as a true/false positive/negative. The raw decision
    value is appended with five decimals.

    Args:
        prediction: Signed decision-function value for ONE sample (a
            scalar, e.g. one entry of ``clf.decision_function(...)``),
            not a list of probabilities.
        real: Ground-truth label of that same sample (a scalar), compared
            for equality with the predicted label 1 or 0.
            NOTE(review): if the class vector encodes negatives as -1
            instead of 0, every correctly rejected sample would be titled
            "False negative" -- confirm the label encoding.

    Returns:
        A string such as ``"True positive - distance: 0.12345"``.
    """
    # A decision value of exactly 0 counts as a negative prediction.
    predicted = 1 if prediction > 0 else 0
    correctness = "True" if predicted == real else "False"
    polarity = "positive" if predicted == 1 else "negative"
    return "%s %s - distance: %.5f" % (correctness, polarity, prediction)
def get_svm_importances(coef):
    """Scale the SVM weights to unit L2 norm and flatten them.

    Args:
        coef: Weight array of a linear SVM (e.g. ``LinearSVC.coef_``).
            Any array-like (including nested lists) is accepted.

    Returns:
        1-D ndarray holding the entries of ``coef`` divided by the norm of
        the whole array (Frobenius norm for 2-D input, which equals the L2
        norm of the flattened weights). An all-zero ``coef`` yields an
        all-zero vector rather than NaNs.
    """
    weights = np.asarray(coef)
    norm = np.linalg.norm(weights)
    if norm == 0:
        # 1.0 / 0 would be inf, and 0 * inf is NaN for every weight.
        return np.zeros(weights.size)
    factor = 1.0 / norm
    return (weights * factor).ravel()
def main(category="Faces", dataset="all", results_subdir=None):
    """Visualize the stored SVM's decision for every image of a category.

    Loads the previously saved classifier for ``category`` and computes
    its signed decision value for every sample of ``dataset``. It then
    passes the image names, the normalized SVM weights and one
    TP/FP/TN/FN title per image to ``Visualization.visualize_images``.

    Args:
        category: Category whose classifier, samples, labels and image
            names are requested from the data manager.
        dataset: Dataset name passed to the data manager's builders.
        results_subdir: Directory (relative to ``PATHS["BASE"]``) assigned
            to ``PATHS["RESULTS"]``. Defaults to
            ``"results_<category>_LinearSVC_normalized"``, which for the
            default category is the directory this script always used.
            Presumably this is where ``load_object`` reads the classifier
            from -- confirm against VisualConceptDetection.
    """
    if results_subdir is None:
        results_subdir = "results_%s_LinearSVC_normalized" % category

    svm = LinearSVC(C=0.1)
    datamanager = CaltechManager()
    datamanager.PATHS["RESULTS"] = os.path.join(datamanager.PATHS["BASE"],
                                                results_subdir)
    vcd = VisualConceptDetection(svm, datamanager)

    clf = vcd.load_object("Classifier", category)
    importances = get_svm_importances(clf.coef_)

    sample_matrix = vcd.datamanager.build_sample_matrix(dataset, category)
    class_vector = vcd.datamanager.build_class_vector(dataset, category)
    pred = clf.decision_function(sample_matrix)
    # Drop references that are no longer needed so their memory can be
    # reclaimed before the visualization step.
    del clf
    del sample_matrix

    image_titles = [get_image_title(prediction, real)
                    for prediction, real in izip(pred, class_vector)]
    del class_vector

    # NOTE(review): assumes get_image_names yields images in the same order
    # as the rows of the sample matrix / class vector -- confirm.
    img_names = list(vcd.datamanager.get_image_names(dataset, category))

    vis = Visualization(datamanager)
    vis.visualize_images(img_names, importances, image_titles)


if __name__ == "__main__":
    main()