-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprune.py
124 lines (112 loc) · 3.53 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from collections import defaultdict
from queue import Queue
import numpy as np
from decision_tree import DecisionTree, ID3, C4_5, read_csv
def prune(tree, X, Y, header, alpha):
    """Cost-complexity pruning of a trained decision tree, in place.

    Walks the tree bottom-up and collapses any internal node whose loss
    as a single leaf (empirical entropy term plus the alpha penalty) is
    no greater than the summed loss of its current subtree's leaves,
    i.e. minimizes C_alpha(T) = sum_t N_t * H(t) + alpha * |T|.

    Args:
        tree:   trained DecisionTree; nodes expose .leaf, .name, .nodes
                and .leaf_cls. Pruned nodes are mutated in place
                (.leaf set True, .nodes cleared, .leaf_cls set to the
                majority class of the samples reaching them).
        X, Y:   training samples / labels used to measure empirical loss.
        header: feature names aligned with the columns of X.
        alpha:  complexity penalty per leaf; larger alpha prunes more.
    """
    assert tree.root is not None
    # BFS to record each node's parent; `st` ends up in BFS order, so
    # popping it yields children strictly before their parents.
    parents = {tree.root: None}  # node -> parent (None for the root)
    q = Queue()
    q.put(tree.root)
    st = list()
    while not q.empty():
        r = q.get()
        st.append(r)
        if not r.leaf:
            for node in r.nodes.values():
                assert node not in parents  # a tree: one parent per node
                parents[node] = r
                q.put(node)

    mheader = dict((name, i) for i, name in enumerate(header))
    # Label distribution observed at each node; filled at the leaves by
    # routing samples, then accumulated upward while unwinding `st`.
    samples = defaultdict(lambda: defaultdict(int))
    losses = dict()  # node -> loss of the subtree currently rooted there

    # Route every training sample to its leaf.
    for x, y in zip(X, Y):
        node = tree.root
        while not node.leaf:
            i = mheader[node.name]
            node = node.nodes[x[i]]
        assert node.leaf
        samples[node][y] += 1

    while st:
        r = st.pop()  # children are always processed before `r`'s parent
        samples_r = samples[r]
        parent = parents[r]
        if parent is not None:
            # Fold this node's label counts into its parent's distribution.
            # (The root has no parent; do not create a samples[None] entry.)
            ps = samples[parent]
            for k, v in samples_r.items():
                ps[k] += v

        # Loss of `r` if it were a single leaf: N_t * H_t + alpha * 1.
        t = np.array(list(samples_r.values()))
        n = int(t.sum()) if t.size else 0
        if n > 0:
            fr = t / n
            h = -(fr * np.log2(fr)).sum()
        else:
            h = 0.0  # no training samples reach this node: zero entropy
        loss = n * h + alpha * 1

        if r.leaf:
            losses[r] = loss
        else:
            # Subtree loss: sum over the (possibly already pruned) children.
            losses[r] = sum(losses[node] for node in r.nodes.values())
            # Prune when collapsing the subtree to a leaf is no worse.
            if loss <= losses[r]:
                r.leaf = True
                r.nodes = None
                # Majority class of the samples reaching r; None when no
                # training sample reaches this node at all.
                r.leaf_cls = max(samples_r, key=samples_r.get) if samples_r else None
                losses[r] = loss
    print(f"Prune Over, loss: {losses[tree.root]:.5}")
def compute_loss(tree, X, Y, header, alpha):
    """Return the cost-complexity loss of `tree` on the data set (X, Y).

    Computes C_alpha(T) = sum over leaves t of N_t * H(t) + alpha * |T|,
    where N_t is the number of samples routed to leaf t and H(t) is the
    empirical entropy (base 2) of the labels at that leaf.
    """
    col_of = dict((name, i) for i, name in enumerate(header))
    leaf_counts = defaultdict(lambda: defaultdict(int))  # leaf -> label -> count

    # Drop every sample down the tree and tally its label at the leaf.
    for x, y in zip(X, Y):
        node = tree.root
        while not node.leaf:
            node = node.nodes[x[col_of[node.name]]]
        assert node.leaf
        leaf_counts[node][y] += 1

    # Depth-first walk; only leaves contribute to the loss.
    loss = 0
    pending = [tree.root]
    while pending:
        node = pending.pop()
        if not node.leaf:
            pending.extend(node.nodes.values())
            continue
        loss += alpha  # complexity penalty: one unit per leaf
        counts = leaf_counts[node]
        t = np.array([v for v in counts.values()])
        n = sum(t)
        fr = t / n
        h = -(fr * np.log2(fr)).sum()
        loss += n * h
    return loss
if __name__ == '__main__':
    # Demo: train a tree on the book's example table, report the loss
    # before and after pruning, and check training-set accuracy.
    columns, data = read_csv('../data/table5.1.csv')
    use_id = True
    first_col = 0 if use_id else 1  # optionally drop the leading id column
    alpha = 8
    feats = data[:, first_col:-1]
    labels = data[:, -1]
    names = columns[first_col:-1]
    # Splitting criterion: C4.5 (gain ratio); ID3 is the alternative.
    dt = DecisionTree(C4_5)
    dt.train(feats, labels, names)
    loss = compute_loss(dt, feats, labels, names, alpha)
    print(f"Before prune, loss: {loss:.5}")
    print(dt)
    prune(dt, feats, labels, names, alpha=alpha)
    loss = compute_loss(dt, feats, labels, names, alpha)
    print(f"After prune, loss: {loss:.5}")
    print(dt)
    predicted = dt.predict(feats, names)
    acc = (predicted == labels).sum() / len(labels)
    print(f"Accuracy: {acc}")