-
Notifications
You must be signed in to change notification settings - Fork 0
/
Decision_Tree_Titanic.py
148 lines (126 loc) · 4.56 KB
/
Decision_Tree_Titanic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import pandas as pd
import math
import operator
def calc_entropy(data_set):
    """Return the Shannon entropy (base 2) of the class labels in data_set.

    data_set: list of feature vectors; the LAST element of each vector is
    the class label (the file's convention throughout).

    Returns 0.0 when all labels are identical.
    """
    num_entries = len(data_set)
    label_counts = {}
    for feat_vec in data_set:
        current_label = feat_vec[-1]
        # dict.get replaces the original's separate "not in keys()" test
        label_counts[current_label] = label_counts.get(current_label, 0) + 1
    entropy = 0.0
    # iterate the counts directly instead of keys + re-lookup
    for count in label_counts.values():
        prob = float(count) / num_entries
        entropy -= prob * math.log(prob, 2)
    return entropy
def split_data_set(data_set, col, value):
    """Return the rows whose feature at index `col` equals `value`,
    with that column removed from each returned row.

    Rows are assumed to be plain lists (as produced by
    DataFrame.values.tolist()); the originals are never mutated.
    """
    return [row[:col] + row[col + 1:] for row in data_set if row[col] == value]
def choose_best_feat_to_split(data_set):
    """Return the column index of the feature whose split gives the largest
    information gain, or -1 when no split improves on zero gain.

    The last column of each row is the class label and is never considered
    as a split candidate.
    """
    feature_count = len(data_set[0]) - 1
    base_entropy = calc_entropy(data_set)
    total = float(len(data_set))
    best_gain, best_index = 0.0, -1
    for col in range(feature_count):
        # weighted entropy of the partition induced by this feature
        weighted_entropy = 0.0
        for candidate in {row[col] for row in data_set}:
            subset = split_data_set(data_set, col, candidate)
            weighted_entropy += (len(subset) / total) * calc_entropy(subset)
        gain = base_entropy - weighted_entropy
        if gain > best_gain:
            best_gain, best_index = gain, col
    return best_index
def majority(class_list):
    """Return the most common class label in class_list.

    Fix: the original called dict.iteritems(), which was removed in
    Python 3 and raises AttributeError there. max() over the dict with a
    count key preserves the original tie-breaking (the first label to
    reach the top count wins, since both dict iteration order and the
    original stable reverse sort keep insertion order among ties).
    """
    class_count = {}
    for vote in class_list:
        class_count[vote] = class_count.get(vote, 0) + 1
    return max(class_count, key=class_count.get)
def create_tree(data_set, labels):
    """Recursively build a decision tree as nested dicts.

    data_set: list of rows; the last element of each row is the class label.
    labels:   feature names aligned with the feature columns of data_set.

    Returns either a bare class label (leaf) or
    {feature_name: {feature_value: subtree_or_label, ...}}.

    Fix: the original `del labels[best_feat]` mutated the caller's list;
    we now work on a defensive copy so the argument is left untouched.
    """
    class_list = [example[-1] for example in data_set]
    if class_list.count(class_list[0]) == len(class_list):
        # all samples share one class: leaf
        return class_list[0]
    if len(data_set[0]) == 1:
        # no features left to split on: majority-vote leaf
        return majority(class_list)
    labels = labels[:]  # do not mutate the caller's list
    best_feat = choose_best_feat_to_split(data_set)
    # NOTE(review): if no split has positive gain, best_feat is -1 and this
    # picks the last label — same as the original behavior; confirm intended.
    best_feat_label = labels[best_feat]
    tree = {best_feat_label: {}}
    del labels[best_feat]
    for value in {example[best_feat] for example in data_set}:
        tree[best_feat_label][value] = create_tree(
            split_data_set(data_set, best_feat, value), labels[:])
    return tree
def classify(tree, labels, data):
    """Walk the nested-dict tree and return the predicted label for one sample.

    tree:   nested dict produced by create_tree (single root key)
    labels: feature names, aligned with the positions in `data`
    data:   one feature vector

    Raises KeyError when the sample carries a feature value that never
    appeared in training (no branch for it).

    Fix: dict.keys() returns a non-indexable view in Python 3, so the
    original `tree.keys()[0]` crashes there; next(iter(tree)) fetches the
    single root key in both Python 2 and 3.
    """
    root = next(iter(tree))
    subtree = tree[root]
    branch = subtree[data[labels.index(root)]]
    if isinstance(branch, dict):
        return classify(branch, labels, data)
    return branch
def classify_all(tree, labels, data_set):
    """Classify every sample in data_set; return the list of predicted labels.

    Fix: the original kept a dead counter `i` whose only use was a
    commented-out print; removed, along with the manual append loop.
    """
    return [classify(tree, labels, data) for data in data_set]
def get_dataset(file_path, is_train):
    """Load a Titanic CSV and keep only the columns this model uses.

    file_path: path to a CSV readable by pandas.
    is_train:  when True the 'Survived' column is kept as the last entry
               of every row (the class label); it is never part of labels.

    Returns (data_set, labels) where data_set is a list of row lists and
    labels is the feature-name list ['Sex', 'SibSp'].
    """
    feature_cols = ['Sex', 'SibSp']
    wanted = feature_cols + ['Survived'] if is_train else feature_cols
    frame = pd.read_csv(file_path)[wanted]
    return frame.values.tolist(), list(feature_cols)
def get_passenger_id(path):
    """Read the CSV at `path` and return its PassengerId column as a
    list of one-element lists (DataFrame.values.tolist() shape)."""
    frame = pd.read_csv(path)
    return frame[['PassengerId']].values.tolist()
def get_result(train_file_path, test_file_path):
    """Train a decision tree on train_file_path, classify test_file_path,
    and write the Kaggle submission to ./result.csv.

    Side effects: prints the tree and the predictions, writes result.csv
    in the current working directory.

    Fixes: Python 2 `print` statements (syntax errors under Python 3);
    the output file is now managed by a `with` block so it is closed even
    on error; the manual index loop is replaced by zip().
    """
    train_data, train_labels = get_dataset(train_file_path, True)
    test_data, test_labels = get_dataset(test_file_path, False)
    tree = create_tree(train_data, train_labels)
    print(tree)
    class_labels = classify_all(tree, test_labels, test_data)
    print(class_labels)
    passenger_id_list = get_passenger_id(test_file_path)
    with open('result.csv', 'w') as f:
        f.write('PassengerId,Survived\n')
        # passenger_id_list rows are one-element lists: [PassengerId]
        for pid, label in zip(passenger_id_list, class_labels):
            f.write(str(pid[0]) + ',' + str(label) + '\n')
if __name__=='__main__':
    # Entry point for the Kaggle Titanic competition: expects the
    # competition's train/test CSVs under ./data/.
    get_result('./data/train.csv', './data/test.csv')