DecisionTree.java
import java.util.*;

public class DecisionTree {
    // Number of features available in each instance.
    private int features;

    public DecisionTree(int features) {
        this.features = features;
    }
    /**
     * Learns a decision tree by recursively splitting on the attribute with the
     * highest information gain, stopping when a split fails the chi-square test.
     * Returns a DTreeNode that holds the attribute name, index, and a set of
     * branches to children.
     *
     * @param instances Set of instances to train from
     * @param featNames Array of feature names corresponding to attribute indices
     * @param testAttr  Set of attribute indices already split on along this path
     * @param thresh    Significance threshold for the chi-square stopping test
     * @return A DTreeNode containing the attribute split on and branches to any children
     */
    public DTreeNode learnTree(Set<Instance> instances, String[] featNames, Set<Integer> testAttr, double thresh) {
        int positive = 0;
        for (Instance instance : instances) {
            if (instance.getLabel() == 1) { positive++; }
        }
        // Base cases: a pure set becomes a leaf; once every attribute has been
        // used, return a leaf with the majority label.
        if (positive == instances.size()) {
            return new DTreeNode(1);
        } else if (positive == 0) {
            return new DTreeNode(0);
        } else if (testAttr.size() == this.features) {
            return (positive >= instances.size() - positive) ? new DTreeNode(1) : new DTreeNode(0);
        } else {
            // Compute the attribute with the maximum information gain.
            int attrIndex = -1;
            double maxGain = -Double.MAX_VALUE;
            for (int i = 0; i < this.features; i++) {
                if (!testAttr.contains(i)) {
                    double gain = informationGain(instances, i);
                    if (maxGain < gain) {
                        maxGain = gain;
                        attrIndex = i;
                    }
                }
            }
            // Retrieve the map of possible values to their subsets, then stop
            // splitting if the chi-square statistic falls below the critical
            // value (the split is not statistically significant).
            Map<Integer, Set<Instance>> range = computeRange(instances, attrIndex);
            Chi chi = new Chi();
            if (chiSquare(positive, instances.size(), range) <= chi.critchi(thresh, range.keySet().size() - 1)) {
                return (positive >= instances.size() - positive) ? new DTreeNode(1) : new DTreeNode(0);
            }
            // Create a node for the attribute we chose to split on and mark it as used.
            int defaultLabel = (positive >= instances.size() - positive) ? 1 : 0;
            DTreeNode node = new DTreeNode(featNames[attrIndex], attrIndex, defaultLabel, new HashMap<Integer, DTreeNode>());
            testAttr.add(attrIndex);
            // Recursive branching over all observed values of the chosen attribute,
            // giving each branch its own copy of the used-attribute set.
            for (Integer value : range.keySet()) {
                node.getBranches().put(value, learnTree(range.get(value), featNames, new HashSet<Integer>(testAttr), thresh));
            }
            return node;
        }
    }
    /**
     * Computes the chi-square statistic for a candidate split, to be compared
     * against the critical value for the chosen significance threshold. For each
     * attribute value v with subset S_v, the expected counts are
     * p' = positive * |S_v| / total and n' = (total - positive) * |S_v| / total,
     * and the statistic sums (p' - p_v)^2 / p' + (n' - n_v)^2 / n' over all v.
     *
     * @param positive Total number of positive examples in the current set to split on.
     * @param total    Total number of examples in the current set.
     * @param range    Map of attribute values to the subset of examples for each.
     * @return The chi-square value to compare against the critical value.
     */
    public double chiSquare(int positive, int total, Map<Integer, Set<Instance>> range) {
        double sum = 0;
        for (Integer value : range.keySet()) {
            // Expected positive and negative counts for this value's subset.
            double pPrime = ((double) positive) * range.get(value).size() / total;
            double nPrime = ((double) (total - positive)) * range.get(value).size() / total;
            int pos = 0;
            for (Instance instance : range.get(value)) {
                if (instance.getLabel() == 1) { pos++; }
            }
            // Squared deviation of observed from expected, normalized by expected.
            sum += (Math.pow(pPrime - pos, 2) / pPrime) + (Math.pow(nPrime - (range.get(value).size() - pos), 2) / nPrime);
        }
        return sum;
    }
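
    // Worked example: with 2 positives among 4 examples split evenly across two
    // values (one subset all positive, the other all negative), every expected
    // count is 1, so the statistic is (1-0)^2 + (1-2)^2 + (1-2)^2 + (1-0)^2 = 4,
    // which exceeds the ~3.84 critical value at thresh = 0.05 with 1 degree of freedom.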
    /**
     * Computes the information gain for splitting on a particular attribute.
     *
     * @param instances      Set of instances to evaluate the split on.
     * @param attributeIndex Index of the attribute we wish to split on.
     * @return The information gain for this particular attribute split.
     */
    public double informationGain(Set<Instance> instances, int attributeIndex) {
        double entropyS = entropy(instances);
        double gain = 0;
        Map<Integer, Set<Instance>> values = computeRange(instances, attributeIndex);
        // Sum each subset's entropy weighted by the fraction of instances it holds.
        for (Integer value : values.keySet()) {
            gain += values.get(value).size() / ((double) instances.size()) * entropy(values.get(value));
        }
        return entropyS - gain;
    }
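
    // Worked example: a 50/50 labeled set has entropy 1; an attribute whose
    // values perfectly separate the labels produces pure subsets (entropy 0),
    // so its information gain is 1 - 0 = 1 bit.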
    /**
     * Partitions the given instances by their value for the given attribute.
     *
     * @param instances      Current set of available examples.
     * @param attributeIndex Index of the attribute we wish to split on.
     * @return A Map from each observed attribute value to the subset of instances with that value.
     */
    public Map<Integer, Set<Instance>> computeRange(Set<Instance> instances, int attributeIndex) {
        Map<Integer, Set<Instance>> values = new HashMap<Integer, Set<Instance>>();
        for (Instance instance : instances) {
            int value = instance.getFeatures()[attributeIndex];
            if (!values.containsKey(value)) {
                values.put(value, new HashSet<Instance>());
            }
            values.get(value).add(instance);
        }
        return values;
    }
    /**
     * Calculates the entropy of a given set of instances.
     *
     * @param instances Set of instances to calculate the entropy of.
     * @return A double representing the entropy value in bits.
     */
    public double entropy(Set<Instance> instances) {
        int tot = instances.size();
        int pos = 0;
        for (Instance p : instances) {
            if (p.getLabel() == 1) { pos++; }
        }
        // A pure set has zero entropy (and this check avoids log(0) below).
        if (pos == 0 || pos == tot) { return 0; }
        double p = (double) pos / tot;
        double n = (double) (tot - pos) / tot;
        // H(S) = -p*log2(p) - n*log2(n), with log2(x) = ln(x) / ln(2).
        return (-p * Math.log(p) - n * Math.log(n)) / Math.log(2);
    }
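
    // Worked example: 8 positives and 2 negatives gives
    // -0.8*log2(0.8) - 0.2*log2(0.2) ≈ 0.722 bits; a 50/50 split gives 1 bit.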
    /**
     * Predicts the class label for a single instance by walking the tree.
     *
     * @param instance Instance whose label should be predicted
     * @param root     DTreeNode root of the decision tree built from the training data
     * @return The predicted class label
     */
    public int predictTree(Instance instance, DTreeNode root) {
        // A node with no branches is a leaf; return its label.
        if (root.getBranches() == null) {
            return root.getLabel();
        }
        int value = instance.getFeatures()[root.getIndex()];
        // Fall back to the node's default (majority) label for unseen values.
        if (root.getBranches().get(value) == null) {
            return root.getDefaultLabel();
        }
        return predictTree(instance, root.getBranches().get(value));
    }
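
    /**
     * Illustrative usage sketch, not part of the original API: trains a tree on a
     * tiny hand-built dataset and classifies one example. The
     * Instance(int[] features, int label) constructor is assumed here and may
     * differ from the actual Instance class in this repository.
     */
    public static void main(String[] args) {
        String[] featNames = {"color", "shape"};
        Set<Instance> train = new HashSet<Instance>();
        // Hypothetical constructor: Instance(int[] features, int label).
        train.add(new Instance(new int[]{0, 0}, 0));
        train.add(new Instance(new int[]{0, 1}, 0));
        train.add(new Instance(new int[]{1, 0}, 1));
        train.add(new Instance(new int[]{1, 1}, 1));
        DecisionTree tree = new DecisionTree(2);
        // 0.05 is a common chi-square significance threshold for pruning.
        DTreeNode root = tree.learnTree(train, featNames, new HashSet<Integer>(), 0.05);
        // The first feature perfectly separates the labels, so this prints 1.
        System.out.println(tree.predictTree(new Instance(new int[]{1, 0}, 1), root));
    }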
}