Skip to content

Commit 5d9cd09

Browse files
committed
add confusion matrix
1 parent eeda39b commit 5d9cd09

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

naive bayes text classifier.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
2+
import os
3+
import re
4+
import string
5+
import pickle
6+
import math
7+
from pandas import DataFrame
8+
path =r"C:\Users\Hasitha\Downloads\20_newsgroups"
9+
10+
# get the list of 20 news catagories....
11+
def get_cat_list():
12+
return os.listdir(path)
13+
14+
# open the files from each catagories
15+
def get_files_from_cat(cat):
16+
fileNameList = os.listdir(path+"\\"+cat)
17+
files=[]
18+
for a in fileNameList:
19+
file =open(path+"\\"+cat+"\\"+a)
20+
files.append(file)
21+
return files
22+
23+
# remove punctuations and clean the string
24+
def remove_punctuations_and_clean(a):
25+
a = a.replace('\r', '').replace('\n', ' ')
26+
for c in string.punctuation:
27+
a= a.replace(c," ")
28+
a=re.sub(' +',' ',a)
29+
return a
30+
31+
# to create vocabulary, get words in given catagory
32+
def list_words_in_cat(c):
33+
words=[]
34+
a= get_files_from_cat(c)
35+
for b in a:
36+
l=b.read()
37+
l=remove_punctuations_and_clean(l)
38+
words=words+l.split()
39+
return words
40+
41+
# split the given string to words and clean
42+
def get_word_list_from_string(string):
43+
words=[]
44+
l=remove_punctuations_and_clean(string)
45+
words=words+l.split()
46+
return words
47+
48+
# read all documents from all catagories and create word corpos
49+
def get_vocabulary():
50+
vocabulary={}
51+
for r in get_cat_list():
52+
w=list_words_in_cat(r)
53+
print(len(w))
54+
for a in w:
55+
vocabulary[a]=1
56+
return vocabulary
57+
58+
# create the p(o|H) values for all words
59+
def get_prob_Poh(vocabulary):
60+
cats = get_cat_list()
61+
p_word_given_cat={}
62+
for cat in cats:
63+
print(cat)
64+
word_list_from_cat= list_words_in_cat(cat)
65+
p_word_given_cat[cat]={}
66+
for word in vocabulary.keys():
67+
p_word_given_cat[cat][word]=1.0
68+
for word in word_list_from_cat:
69+
if word in vocabulary:
70+
p_word_given_cat[cat][word] += 1.0
71+
for word in vocabulary.keys():
72+
p_word_given_cat[cat][word]=(p_word_given_cat[cat][word])/(len(vocabulary)+len(word_list_from_cat));
73+
print(len(p_word_given_cat[cat]))
74+
return p_word_given_cat
75+
76+
# create the tast data set
77+
def list_strings_in_cat_test(cat,size):
78+
words_test={}
79+
a = get_files_from_cat(cat);
80+
counter =0
81+
for b in a[size:]:
82+
words_test[counter]=""
83+
l=b.read()
84+
words_test[counter]=l
85+
counter+=1
86+
return words_test
87+
88+
# testing the accuracy from test dataset
89+
def test_accuracy(size,prob_Poh,voc):
90+
cats=get_cat_list()
91+
overall=0
92+
length=0
93+
for cat in cats:
94+
test_data=list_strings_in_cat_test(cat,size)
95+
success_predicts=0
96+
for i in range (0,len(test_data)):
97+
a=classifier(test_data[i],prob_Poh,voc)
98+
if (a==cat):
99+
success_predicts+=1
100+
length+=len(test_data)
101+
overall+=success_predicts
102+
return (overall/(length))
103+
104+
# classifier
105+
def classifier(text_to_classify,prob_Poh,vocabulary):
106+
max_group=""
107+
max_prob=1
108+
for candidate_group in get_cat_list():
109+
p=math .log(1/20)
110+
for word in get_word_list_from_string(text_to_classify):
111+
if word in vocabulary:
112+
p=p+math .log(prob_Poh[candidate_group][word])
113+
if (p>max_prob or max_prob==1):
114+
max_prob=p
115+
max_group=candidate_group
116+
return max_group
117+
118+
# confusion matrix genarater
119+
def confusion_matrix(size,prob_Poh,voc):
120+
cats=get_cat_list()
121+
c_matrix={}
122+
for cat in cats:
123+
c_matrix[cat]={}
124+
for catInside in cats:
125+
c_matrix[cat][catInside]=0
126+
for cat in cats:
127+
test_data=list_strings_in_cat_test(cat,size)
128+
for i in range (0,len(test_data)):
129+
a=classifier(test_data[i],prob_Poh,voc)
130+
c_matrix[cat][a]+=1
131+
return c_matrix
132+
133+
# saving confusion matrix to a excel file
134+
def save_confusion_matrix_to_excel_file(c):
135+
df = DataFrame(c)
136+
df.to_excel('test.xlsx', sheet_name='sheet1', index=False)
137+
138+
139+
# ///////////// RUN
140+
141+
voc=get_vocabulary()
142+
prob_Poh=get_prob_Poh(voc)
143+
144+
print(test_accuracy(980,prob_Poh,voc))
145+
146+
c=confusion_matrix(900,prob_Poh,voc)
147+
save_confusion_matrix_to_excel_file(c)
148+

test.xlsx

6.52 KB
Binary file not shown.

0 commit comments

Comments
 (0)